/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "moco/Service/TFShapeInferenceRule.h"

#include "moco/IR/TFDialect.h"
#include "moco/IR/TFNode.h"

#include <moco/Support/TFShapeInferenceHelper.h>

#include <loco/IR/NodeShape.h>
#include <loco/Service/ShapeInference.h>

#include <oops/UserExn.h>

#include <cassert>
#include <cstdint>
#include <set>
#include <vector>
35 class ShapeInferenceAlgorithm final : public moco::TFNodeVisitor<loco::NodeShape>
38 ShapeInferenceAlgorithm(const loco::ShapeInferenceRule::Context *ctx) : _ctx{ctx}
44 const loco::ShapeInferenceRule::Context *_ctx;
47 bool shape_known(const loco::Node *node) const { return _ctx->known(node); }
48 loco::NodeShape node_shape(const loco::Node *node) const { return _ctx->get(node); }
51 loco::NodeShape binary_node_shape(const moco::TFNode::Node *node)
53 // This helper works only for binary node.
54 assert(node->arity() == 2);
56 auto lhs_shape = node_shape(node->arg(0));
57 auto rhs_shape = node_shape(node->arg(1));
59 loco::TensorShape lhs_tensorshape = lhs_shape.as<loco::TensorShape>();
60 loco::TensorShape rhs_tensorshape = rhs_shape.as<loco::TensorShape>();
61 loco::TensorShape sum_tensorshape = moco::broadcast_shape(lhs_tensorshape, rhs_tensorshape);
63 loco::NodeShape sum_shape({sum_tensorshape});
68 loco::NodeShape node_shape_with_check(const moco::TFNode::Node *node)
70 auto nodeshape = node_shape(node);
71 assert(nodeshape.domain() == loco::Domain::Tensor);
76 bool valid_scalar_value(moco::TFConst *node)
78 auto nodeshape = node_shape(node);
79 if (nodeshape.domain() != loco::Domain::Tensor)
83 if (node->dtype() != loco::DataType::S32)
88 auto tensor_shape = nodeshape.as<loco::TensorShape>();
89 if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1))
97 int32_t scalar_value(moco::TFConst *node)
99 auto nodeshape = node_shape(node);
100 assert(node->dtype() == loco::DataType::S32);
102 auto tensor_shape = nodeshape.as<loco::TensorShape>();
103 assert(tensor_shape.rank() == 0 || tensor_shape.rank() == 1);
105 return node->at<loco::DataType::S32>(0);
109 loco::NodeShape visit(const moco::TFAdd *node) final { return binary_node_shape(node); }
111 loco::NodeShape visit(const moco::TFAvgPool *node) final
113 auto value_shape = node_shape(node->value());
114 assert(value_shape.domain() != loco::Domain::Unknown);
116 moco::PlaneInference infer_plane_shape;
118 infer_plane_shape.padding(node->padding());
119 infer_plane_shape.stride(moco::stride_of(node->strides(), node->data_layout()));
120 infer_plane_shape.window(moco::window_of(node->ksize(), node->data_layout()));
122 auto input_feature_shape = moco::as_feature_shape(value_shape, node->data_layout());
123 auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
124 auto output_feature_shape = input_feature_shape;
125 auto output_plane_shape = infer_plane_shape(input_plane_shape);
127 moco::update(output_feature_shape).with(output_plane_shape);
129 return moco::as_tensor_shape(output_feature_shape, node->data_layout());
132 loco::NodeShape visit(const moco::TFBiasAdd *node) final
134 return node_shape_with_check(node->value());
137 loco::NodeShape visit(const moco::TFConcatV2 *node) final
139 // axis shape should be available
140 auto axis_node = node->axis();
141 auto axis_shape = node_shape(axis_node);
142 assert(axis_shape.domain() != loco::Domain::Unknown);
144 // check all input shapes and all ranks should be same
145 auto value_a = node->values(0);
146 auto value_a_shape = node_shape(value_a);
147 assert(value_a_shape.domain() == loco::Domain::Tensor);
148 auto value_a_tensor_shape = value_a_shape.as<loco::TensorShape>();
149 uint32_t a_rank = value_a_tensor_shape.rank();
151 uint32_t num_values = node->num_values();
152 for (uint32_t ni = 1; ni < num_values; ++ni)
154 auto value_b = node->values(ni);
155 auto value_b_shape = node_shape(value_b);
156 assert(value_b_shape.domain() == loco::Domain::Tensor);
157 auto value_b_tensor_shape = value_b_shape.as<loco::TensorShape>();
158 assert(a_rank == value_b_tensor_shape.rank());
161 int32_t axis_value = 0;
162 bool axis_available = false;
164 // check for axis is TFConst
165 auto tfconst = dynamic_cast<moco::TFConst *>(axis_node);
166 if (tfconst != nullptr)
168 if (valid_scalar_value(tfconst))
170 axis_value = scalar_value(tfconst);
171 axis_available = true;
177 // TODO may need to refine error message
178 throw oops::UserExn("ConcatV2 node does not have axis input", node->name());
181 uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)a_rank + axis_value;
182 loco::TensorShape output_tensor_shape = value_a_tensor_shape;
184 for (uint32_t index = 0; index < a_rank; ++index)
186 if (value_a_tensor_shape.dim(index).known())
188 uint32_t dim = value_a_tensor_shape.dim(index).value();
189 uint32_t dim_acc = dim;
191 for (uint32_t ni = 1; ni < num_values; ++ni)
193 auto value_b = node->values(ni);
194 auto value_b_shape = node_shape(value_b);
195 assert(value_b_shape.domain() == loco::Domain::Tensor);
196 auto value_b_tensor_shape = value_b_shape.as<loco::TensorShape>();
197 assert(value_b_tensor_shape.dim(index).known());
198 if (index == axis_absolute)
199 dim_acc += value_b_tensor_shape.dim(index).value();
201 assert(dim == value_b_tensor_shape.dim(index).value());
203 output_tensor_shape.dim(index) = dim_acc;
206 output_tensor_shape.dim(index).unset();
208 return loco::NodeShape(output_tensor_shape);
211 loco::NodeShape visit(const moco::TFConst *node) final
213 loco::TensorShape output_tensor_shape;
215 uint32_t rank = node->rank();
216 output_tensor_shape.rank(rank);
217 for (uint32_t index = 0; index < rank; ++index)
219 if (node->dim(index).known())
220 output_tensor_shape.dim(index) = node->dim(index).value();
222 output_tensor_shape.dim(index).unset();
225 return loco::NodeShape(output_tensor_shape);
228 loco::NodeShape visit(const moco::TFConv2D *node) final
230 auto input_shape = moco::node_shape(node->input());
231 auto ker_shape = moco::node_shape(node->filter());
232 auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWIO
233 auto node_stride = moco::stride_of(node->strides(), node->data_layout());
234 auto node_window = moco::window_of(ker_tensor_shape, "HWIO");
236 moco::PlaneInference infer_plane_shape;
238 infer_plane_shape.padding(node->padding());
239 infer_plane_shape.stride(node_stride);
240 infer_plane_shape.window(node_window);
242 auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
243 auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
244 // output count is from input count, depth is from kernel 'O' which is dim(3)
245 auto output_feature_shape = input_feature_shape;
246 output_feature_shape.depth() = ker_tensor_shape.dim(3).value();
248 auto output_plane_shape = infer_plane_shape(input_plane_shape);
250 moco::update(output_feature_shape).with(output_plane_shape);
252 return moco::as_tensor_shape(output_feature_shape, node->data_layout());
255 loco::NodeShape visit(const moco::TFConv2DBackpropInput *node) final
257 // TFConv2DBackpropInput's first input, named 'input_sizes', actually contains shape of node
258 // output's feature map. We can get shape of TFConv2DBackpropInput by just copying this.
259 // TODO Support when 'input_sizes' is not TFConst, or support constant folding
260 auto input_sizes_node = dynamic_cast<moco::TFConst *>(node->input_sizes());
261 if (input_sizes_node == nullptr)
263 // we are now supporting somekind of constant folding for this node, wait till it is finished
264 loco::NodeShape unknown;
268 // Let's support S32 for time being
269 // TODO Support other integer types
270 assert(input_sizes_node->dtype() == loco::DataType::S32);
271 assert(input_sizes_node->size<loco::DataType::S32>() == 4);
274 loco::TensorShape ofm_tensor_shape;
275 ofm_tensor_shape.rank(4);
276 for (uint32_t i = 0; i < 4; ++i)
278 int32_t dim = input_sizes_node->at<loco::DataType::S32>(i);
280 ofm_tensor_shape.dim(i) = (uint32_t)dim;
283 return loco::NodeShape(ofm_tensor_shape);
286 loco::NodeShape visit(const moco::TFDepthwiseConv2dNative *node) final
288 auto input_shape = moco::node_shape(node->input()); // NHWC
289 auto ker_shape = moco::node_shape(node->filter());
290 auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWCM
291 auto node_stride = moco::stride_of(node->strides(), node->data_layout());
292 auto node_window = moco::window_of(ker_tensor_shape, "HWCM");
294 moco::PlaneInference infer_plane_shape;
296 infer_plane_shape.padding(node->padding());
297 infer_plane_shape.stride(node_stride);
298 infer_plane_shape.window(node_window);
300 auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
301 auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
302 // output count is from input count, depth is from kernel 'CM' which is dim(2) * dim(3)
303 auto output_feature_shape = input_feature_shape;
304 output_feature_shape.depth() =
305 loco::Dimension(ker_tensor_shape.dim(2).value() * ker_tensor_shape.dim(3).value());
307 auto output_plane_shape = infer_plane_shape(input_plane_shape);
309 moco::update(output_feature_shape).with(output_plane_shape);
311 return moco::as_tensor_shape(output_feature_shape, node->data_layout());
314 loco::NodeShape visit(const moco::TFFakeQuantWithMinMaxVars *node) final
316 return node_shape_with_check(node->inputs());
319 loco::NodeShape visit(const moco::TFFusedBatchNorm *node) final
321 return node_shape_with_check(node->x());
324 loco::NodeShape visit(const moco::TFIdentity *node) final
326 return node_shape_with_check(node->input());
329 loco::NodeShape visit(const moco::TFMaximum *node) final { return binary_node_shape(node); }
331 loco::NodeShape visit(const moco::TFMaxPool *node) final
333 auto input_shape = node_shape(node->input());
334 assert(input_shape.domain() != loco::Domain::Unknown);
336 moco::PlaneInference infer_plane_shape;
338 infer_plane_shape.padding(node->padding());
339 infer_plane_shape.stride(moco::stride_of(node->strides(), node->data_layout()));
340 infer_plane_shape.window(moco::window_of(node->ksize(), node->data_layout()));
342 auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
343 auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
344 auto output_feature_shape = input_feature_shape;
345 auto output_plane_shape = infer_plane_shape(input_plane_shape);
347 moco::update(output_feature_shape).with(output_plane_shape);
349 return moco::as_tensor_shape(output_feature_shape, node->data_layout());
352 loco::NodeShape visit(const moco::TFMean *node) final
354 auto input_shape = node_shape(node->input());
355 auto reduction_indices = node->reduction_indices();
357 // Get constant values if reduction_indices is const
358 std::vector<int32_t> reduction_values;
359 if (auto tfconst = dynamic_cast<moco::TFConst *>(reduction_indices))
361 assert(tfconst->dtype() == loco::DataType::S32);
362 auto const_size = tfconst->size<loco::DataType::S32>();
363 for (uint32_t i = 0; i < const_size; ++i)
365 int32_t axis = tfconst->at<loco::DataType::S32>(i);
367 axis += input_shape.as<loco::TensorShape>().rank();
368 reduction_values.push_back(axis);
373 // we cannot find a valid reduction indices value
374 loco::NodeShape unknown;
378 loco::TensorShape output_shape;
379 auto input_tensor_shape = input_shape.as<loco::TensorShape>();
381 if (node->keep_dims())
383 output_shape.rank(input_tensor_shape.rank());
384 for (uint32_t i = 0; i < input_tensor_shape.rank(); ++i)
385 output_shape.dim(i) = input_tensor_shape.dim(i);
386 for (uint32_t i = 0; i < reduction_values.size(); ++i)
387 output_shape.dim(reduction_values.at(i)) = 1;
391 std::vector<bool> check_reduce(input_tensor_shape.rank(), false);
392 for (uint32_t i = 0; i < reduction_values.size(); ++i)
393 check_reduce.at(reduction_values.at(i)) = true;
395 uint32_t reduce_cnt = 0;
396 for (uint32_t i = 0; i < check_reduce.size(); ++i)
397 if (check_reduce.at(i))
400 output_shape.rank(input_tensor_shape.rank() - reduce_cnt);
401 for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
402 if (check_reduce.at(i) == false)
403 output_shape.dim(j++) = i;
406 return loco::NodeShape(output_shape);
409 loco::NodeShape visit(const moco::TFMul *node) final { return binary_node_shape(node); }
411 loco::NodeShape visit(const moco::TFPack *node) final
413 loco::NodeShape unknown;
415 auto input_shape_0 = node_shape(node->values(0));
416 if (input_shape_0.domain() != loco::Domain::Tensor)
418 // TODO fix this for other cases
419 // We support only valid tensor shape for now
422 loco::TensorShape tensor_shape_0 = input_shape_0.as<loco::TensorShape>();
424 // all input shapes should be same
425 auto num_values = node->N();
426 for (uint32_t i = 1; i < num_values; ++i)
428 auto input_shape = node_shape(node->values(i));
429 if (input_shape.domain() != loco::Domain::Tensor)
435 loco::TensorShape tensor_shape = input_shape.as<loco::TensorShape>();
436 if (!(input_shape_0 == input_shape))
438 throw oops::UserExn("All input values shape should be same", node->name());
442 // output rank will be +1 of rank of the input
443 // axis should be in range of [-r, r), where r is rank of the output
444 auto axis = node->axis();
445 int32_t rank = static_cast<int32_t>(tensor_shape_0.rank());
447 int32_t rank_output = rank + 1;
448 if (axis < -rank_output || rank_output <= axis)
450 throw oops::UserExn("axis is out of range", node->name());
453 auto axis_stack = (axis >= 0) ? axis : rank_output + axis;
455 loco::TensorShape output_tensor_shape;
457 output_tensor_shape.rank(rank_output);
458 for (int32_t r = 0; r < axis_stack; ++r)
460 output_tensor_shape.dim(r).set(tensor_shape_0.dim(r).value());
462 output_tensor_shape.dim(axis_stack).set(num_values);
463 for (int32_t r = axis_stack; r < rank; ++r)
465 output_tensor_shape.dim(r + 1).set(tensor_shape_0.dim(r).value());
468 return loco::NodeShape(output_tensor_shape);
471 loco::NodeShape visit(const moco::TFPad *node) final
473 auto input_shape = node_shape(node->input());
474 assert(input_shape.domain() == loco::Domain::Tensor);
476 auto const_paddings = loco::must_cast<moco::TFConst *>(node->paddings());
477 assert(const_paddings->dtype() == loco::DataType::S32);
478 assert(const_paddings->rank() == 2);
480 loco::TensorShape input_tensor_shape = input_shape.as<loco::TensorShape>();
481 loco::TensorShape output_tensor_shape;
483 output_tensor_shape.rank(input_tensor_shape.rank());
484 for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
486 output_tensor_shape.dim(axis) = input_tensor_shape.dim(axis).value() +
487 const_paddings->at<loco::DataType::S32>(axis * 2) +
488 const_paddings->at<loco::DataType::S32>(axis * 2 + 1);
491 return loco::NodeShape{output_tensor_shape};
494 loco::NodeShape visit(const moco::TFPlaceholder *node) final
496 loco::TensorShape output_tensor_shape;
498 uint32_t rank = node->rank();
499 output_tensor_shape.rank(rank);
500 for (uint32_t index = 0; index < rank; ++index)
502 if (node->dim(index).known())
503 output_tensor_shape.dim(index) = node->dim(index).value();
505 output_tensor_shape.dim(index).unset();
508 return loco::NodeShape(output_tensor_shape);
511 loco::NodeShape visit(const moco::TFRealDiv *node) final { return binary_node_shape(node); }
513 loco::NodeShape visit(const moco::TFRelu *node) final
515 return node_shape_with_check(node->features());
518 loco::NodeShape visit(const moco::TFRelu6 *node) final
520 return node_shape_with_check(node->features());
523 loco::NodeShape visit(const moco::TFReshape *node) final
525 loco::NodeShape unknown;
527 // For now, we only consider Fixed Reshape, i.e. Reshape with determined
528 // 'shape' input. So here we only support case when 'shape' input of
529 // TFReshape is TFConst. If 'shape' input is not TFConst, another
530 // transform (e.g. constant folding) should be done beforehand to make
532 // TODO Support dynamic Reshape
533 // Note that 'shape()' here is 'shape' input, not node's shape information
534 auto const_shape_input = dynamic_cast<moco::TFConst *>(node->shape());
535 if (!const_shape_input)
537 // 'shape' input of TFReshape is not TFConst, we can not do shape inference
541 // 'Shape' input should be integer tensor of rank 1, e.g. [2, 3, 4] or [3, -1]
542 assert(const_shape_input->dtype() == loco::DataType::S32);
543 assert(const_shape_input->rank() == 1);
545 auto shape_rank = const_shape_input->dim(0).value();
546 assert(shape_rank > 0);
548 loco::TensorShape output_shape;
549 output_shape.rank(shape_rank);
550 for (uint32_t axis = 0; axis < shape_rank; ++axis)
552 auto shape_dim = const_shape_input->at<loco::DataType::S32>(axis);
555 // Reshape's new shape has wildcard dimension, i.e. dynamic reshape
558 assert(shape_dim >= 1);
559 output_shape.dim(axis) = shape_dim;
562 // TODO Compare 'tensor' input and validate coherency?
563 // Not sure this is appropriate stage for this task.
565 return loco::NodeShape(output_shape);
568 loco::NodeShape visit(const moco::TFRsqrt *node) final
570 return node_shape_with_check(node->x());
573 loco::NodeShape visit(const moco::TFShape *node) final
575 auto input_shape = node_shape(node->input());
576 auto input_tensor_shape = input_shape.as<loco::TensorShape>();
578 loco::TensorShape output_shape;
580 // Note that input shape becomes node(TFShape)'s value
581 output_shape.rank(1);
582 output_shape.dim(0) = input_tensor_shape.rank();
584 return loco::NodeShape(output_shape);
587 loco::NodeShape visit(const moco::TFSoftmax *node) final
589 return node_shape_with_check(node->logits());
592 loco::NodeShape visit(const moco::TFSqrt *node) final { return node_shape_with_check(node->x()); }
594 loco::NodeShape visit(const moco::TFSquaredDifference *node) final
596 return binary_node_shape(node);
599 loco::NodeShape visit(const moco::TFSqueeze *node) final
601 auto input_shape = node_shape(node->input());
603 // TODO Not sure Squeeze only get input as Tensor
604 // Note that tensor_shape() has assertion in it
605 auto input_tensor_shape = input_shape.as<loco::TensorShape>();
607 auto squeeze_dims_vec = node->squeeze_dims();
608 std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
610 loco::TensorShape output_shape;
611 uint32_t output_rank = 0;
613 if (squeeze_dims.empty())
615 // Remove all dimensions whose value is 1
616 for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
618 assert(input_tensor_shape.dim(axis).known());
619 auto dim = input_tensor_shape.dim(axis).value();
623 output_shape.rank(++output_rank);
624 output_shape.dim(output_rank - 1) = dim;
630 uint32_t input_rank = input_tensor_shape.rank();
632 // Sanity check for 'squeeze_dims'
633 auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
634 if (!(squeeze_dims.size() < input_rank))
636 for (auto squeeze_dim : squeeze_dims)
638 if (!(squeeze_dim >= -(int64_t)input_rank))
640 if (!(squeeze_dim < (int64_t)input_rank))
646 if (!is_valid_squeeze_dims())
648 throw oops::UserExn("Invalid squeeze dimension", node->name());
651 // Resolve negative squeeze dimension
652 std::set<int64_t> resolved_squeeze_dims;
653 for (auto squeeze_dim : squeeze_dims)
656 resolved_squeeze_dims.insert(squeeze_dim + (int64_t)input_rank);
658 resolved_squeeze_dims.insert(squeeze_dim);
661 // Remove squeeze dimensions only
662 for (uint32_t axis = 0; axis < input_rank; ++axis)
664 assert(input_tensor_shape.dim(axis).known());
665 auto dim = input_tensor_shape.dim(axis).value();
666 if (resolved_squeeze_dims.find((int64_t)axis) == resolved_squeeze_dims.cend())
669 output_shape.rank(++output_rank);
670 output_shape.dim(output_rank - 1) = dim;
681 assert(output_shape.rank() > 0);
683 return loco::NodeShape(output_shape);
686 loco::NodeShape visit(const moco::TFStopGradient *node) final
688 return node_shape_with_check(node->input());
691 loco::NodeShape visit(const moco::TFStridedSlice *node) final
693 loco::NodeShape unknown;
694 auto input_shape = node_shape(node->input());
695 if (input_shape.domain() != loco::Domain::Tensor)
697 // TODO fix this for other cases
698 // We support only tensor shape for now
702 // TODO support full mask features: see import codes also
703 // Limited attributes for now
704 assert(node->begin_mask() == 0);
705 assert(node->end_mask() == 0);
706 assert(node->ellipsis_mask() == 0);
707 assert(node->shrink_axis_mask() == 1);
709 auto const_begin = loco::must_cast<moco::TFConst *>(node->begin());
710 auto const_end = loco::must_cast<moco::TFConst *>(node->end());
711 auto const_strides = loco::must_cast<moco::TFConst *>(node->strides());
713 assert(dynamic_cast<moco::TFConst *>(node->input()) != nullptr);
714 assert(const_begin != nullptr);
715 assert(const_end != nullptr);
716 assert(const_strides != nullptr);
718 auto input_tensor_shape = input_shape.as<loco::TensorShape>();
719 auto input_rank = input_tensor_shape.rank();
720 auto output_rank = input_rank;
722 // TODO support strides with > 1
723 uint32_t elements = const_strides->size<loco::DataType::S32>();
724 for (uint32_t e = 0; e < elements; ++e)
725 assert(const_strides->at<loco::DataType::S32>(e) == 1);
727 // lets apply begin ~ end range from input shape
728 loco::TensorShape output_shape_range;
730 output_shape_range.rank(input_rank);
731 for (uint32_t r = 0; r < input_rank; ++r)
733 // TODO apply begin/end mask
734 // TODO apply ellipsis mask
735 // TODO apply strides
736 auto end = const_end->at<loco::DataType::S32>(r);
737 auto begin = const_begin->at<loco::DataType::S32>(r);
738 auto size = end - begin;
739 output_shape_range.dim(r).set(size);
742 // get final tensor shape from applying shrink mask to output_shape_range
743 loco::TensorShape output_tensor_shape;
745 if (node->shrink_axis_mask() != 0)
747 for (uint32_t rs = 0; rs < input_rank; ++rs)
749 int32_t bit = 1 << rs;
750 int32_t mask = node->shrink_axis_mask();
753 // shrink one dimension
754 assert(output_rank > 0);
755 output_rank = output_rank - 1;
758 output_tensor_shape.rank(output_rank);
759 for (uint32_t rs = 0, rd = 0; rs < input_rank; ++rs)
761 int32_t bit = 1 << rs;
762 int32_t mask = node->shrink_axis_mask();
763 if ((bit & mask) == 0)
765 // use this dimension
766 output_tensor_shape.dim(rd).set(output_shape_range.dim(rs).value());
769 // else this dimension is shrink-ed
774 output_tensor_shape = output_shape_range;
777 return loco::NodeShape(output_tensor_shape);
780 loco::NodeShape visit(const moco::TFSub *node) final { return binary_node_shape(node); }
782 loco::NodeShape visit(const moco::TFTanh *node) final { return node_shape_with_check(node->x()); }
785 loco::NodeShape visit(const moco::TFPush *node) { return node_shape_with_check(node->from()); }
788 loco::NodeShape visit(const moco::TFNode *) final
790 loco::NodeShape unknown;
802 struct Context final : public loco::ShapeInferenceRule::Context
804 bool known(const loco::Node *node) const final { return loco::shape_known(node); }
805 loco::NodeShape get(const loco::Node *node) const final { return loco::shape_get(node); }
808 class Sink final : public loco::ShapeInferenceRule::Sink
819 const Status &status(void) const { return _status; }
820 const loco::NodeShape &shape(void) const { return _shape; }
823 void okay(const loco::NodeShape &shape) final
829 void fail(void) final
836 Status _status = Unknown;
837 loco::NodeShape _shape;
840 } // namespace compat
846 bool TFShapeInferenceRule::support(const API &api) const
848 return api == API::V1 or api == API::V2;
851 bool TFShapeInferenceRule::recognize(const loco::Dialect *d) const
853 // handle only TensorFlow dialect
854 return TFDialect::get() == d;
857 bool TFShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
859 ::compat::Context ctx;
862 infer(&ctx, node, &sink);
864 assert(sink.status() == ::compat::Sink::Okay or sink.status() == ::compat::Sink::Fail);
866 if (sink.status() == ::compat::Sink::Fail)
871 shape = sink.shape();
876 void TFShapeInferenceRule::infer(const Context *ctx, const loco::Node *node, Sink *sink) const
878 assert(node->dialect() == TFDialect::get());
879 assert(dynamic_cast<const TFNode *>(node) != nullptr);
881 ShapeInferenceAlgorithm alg{ctx};
882 auto shape = loco::must_cast<const TFNode *>(node)->accept(&alg);
884 if (shape.domain() == loco::Domain::Unknown)