2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Service/CircleShapeInferenceRule.h"
20 #include "ShapeInfer_StridedSlice.h"
22 #include <luci/IR/CircleNodes.h>
23 #include <luci/IR/CircleDialect.h>
24 #include <luci/IR/CircleNodeVisitor.h>
27 #include <oops/InternalExn.h>
37 std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
40 for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
44 os << tensor_shape.dim(r).value();
50 loco::TensorShape own_shape(const luci::CircleNode *node)
52 loco::TensorShape shape;
53 shape.rank(node->rank());
54 for (uint32_t r = 0; r < node->rank(); ++r)
55 shape.dim(r) = loco::Dimension(node->dim(r).value());
59 loco::NodeShape use_own(const luci::CircleNode *node)
61 loco::TensorShape shape = own_shape(node);
62 return loco::NodeShape{shape};
66 * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
70 * auto expanded_tensor_shape = expand(tensor_shape).to(N);
72 class TensorShapeExpander
75 TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
81 loco::TensorShape to(uint32_t output_rank)
83 auto const &input_shape = _shape;
84 uint32_t const input_rank = input_shape.rank();
86 assert(input_rank <= output_rank && "Cannot shrink rank");
87 uint32_t const axis_shift = output_rank - input_rank;
89 loco::TensorShape output_shape;
91 output_shape.rank(output_rank);
92 for (uint32_t axis = 0; axis < output_rank; ++axis)
94 output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
101 const loco::TensorShape _shape;
105 * @breif Expand shape x and y to same rank by align right and filling with 1
107 void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
109 auto x_rank = x.rank();
110 auto y_rank = y.rank();
112 if (x_rank == y_rank)
115 TensorShapeExpander x_exp(x);
116 TensorShapeExpander y_exp(y);
118 auto xy_rank = std::max(x_rank, y_rank);
120 x = x_rank > y_rank ? x : x_exp.to(xy_rank);
121 y = y_rank > x_rank ? y : y_exp.to(xy_rank);
125 * @breif Returns shape of expanded dimension of input x and y having same rank
127 loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
129 assert(x.rank() == y.rank());
131 auto rank = x.rank();
133 loco::TensorShape output_shape;
135 output_shape.rank(rank);
136 for (uint32_t axis = 0; axis < rank; ++axis)
138 assert(x.dim(axis).known() && y.dim(axis).known());
140 auto x_dim = x.dim(axis).value();
141 auto y_dim = y.dim(axis).value();
143 // each dimension of x and y should be same or one must be 1 if different
144 if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
145 INTERNAL_EXN("Cannot produce expand_dimension of two shapes");
147 output_shape.dim(axis) = std::max(x_dim, y_dim);
153 loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
158 expand_rank(x_match, y_match);
160 auto output_shape = expand_dimension(x_match, y_match);
166 * @brief vector_from_constant will return int64_t vector from CircleConst node
168 template <loco::DataType T> std::vector<int64_t> vector_from_constant(luci::CircleConst *const_node)
170 std::vector<int64_t> result;
172 for (uint32_t idx = 0; idx < const_node->size<T>(); ++idx)
173 result.push_back(const_node->at<T>(idx));
178 template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node)
180 auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
181 auto y_shape = loco::shape_get(node->y()).template as<loco::TensorShape>();
183 auto output_shape = broadcast_shape(x_shape, y_shape);
185 return loco::NodeShape{output_shape};
188 template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node)
190 auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
191 return loco::NodeShape{x_shape};
194 template <class CIRCLENODE> loco::NodeShape use_logits(const CIRCLENODE *node)
196 auto shape = loco::shape_get(node->logits()).template as<loco::TensorShape>();
197 return loco::NodeShape{shape};
200 template <class CIRCLENODE>
201 loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *paddings)
203 const loco::DataType S32 = loco::DataType::S32;
205 auto input_shape = loco::shape_get(node->input()).template as<loco::TensorShape>();
207 // TODO support other data type
208 LUCI_ASSERT(paddings->dtype() == S32, "Only support int 32 for now");
209 LUCI_ASSERT(paddings->rank() == 2, "paddings should be rank 2")
211 int32_t n = paddings->dim(0).value();
212 int32_t v = paddings->dim(1).value();
214 LUCI_ASSERT(v == 2, "paddings should be [n, 2]");
215 LUCI_ASSERT(n == int32_t(input_shape.rank()),
216 "paddings [n, 2] should have same value of input rank");
218 loco::TensorShape output_shape;
220 output_shape.rank(input_shape.rank());
221 for (int32_t ni = 0; ni < n; ++ni)
223 int32_t idx = ni * 2;
224 int value = input_shape.dim(ni).value();
225 value += paddings->at<S32>(idx + 0); // left
226 value += paddings->at<S32>(idx + 1); // right
227 output_shape.dim(ni) = value;
230 return loco::NodeShape{output_shape};
233 loco::NodeShape infer_add_n(const luci::CircleAddN *node)
235 auto shape = loco::shape_get(node->inputs(0)).as<loco::TensorShape>();
237 for (uint32_t idx = 1; idx < node->arity(); ++idx)
239 auto shape_idx = loco::shape_get(node->inputs(idx)).as<loco::TensorShape>();
240 if (!(shape == shape_idx))
242 INTERNAL_EXN_V("ADD_N shape not same as the first input: ", idx);
245 return loco::NodeShape{shape};
248 loco::NodeShape infer_arg_max(const luci::CircleArgMax *node)
250 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
251 auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
253 int64_t select_axis = 0;
255 LUCI_ASSERT(node->dimension(), "2nd input dimension() should not be nullptr");
257 // Only support node's shape() is CircleConst with S32/S64
258 // Support S32 for now.
259 auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
260 LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
261 "Only support int32 CircleConst for CircleArgMax");
263 if (const_shape_node->rank() > 1)
264 INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
265 oops::to_uint32(const_shape_node->rank()));
267 select_axis = const_shape_node->scalar<loco::DataType::S32>();
269 assert(select_axis < input_shape.rank());
270 assert(select_axis >= 0); // TODO support minus of this breaks
272 // NOTE select_axis is removed
273 loco::TensorShape shape_output;
274 uint32_t rank = input_shape.rank();
275 uint32_t shrink = static_cast<uint32_t>(select_axis);
277 shape_output.rank(rank - 1);
278 for (uint32_t r = 0, d = 0; r < rank; ++r)
282 shape_output.dim(d++) = input_shape.dim(r);
284 return loco::NodeShape{shape_output};
287 loco::NodeShape infer_arg_min(const luci::CircleArgMin *node)
289 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
290 auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
292 int64_t select_axis = 0;
294 LUCI_ASSERT(node->dimension(), "2nd input dimension() should not be nullptr");
296 // Only support node's shape() is CircleConst with S32/S64
297 // Support S32 for now.
298 auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
299 LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
300 "Only support int32 CircleConst for CircleArgMin");
302 if (const_shape_node->rank() > 1)
303 INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
304 oops::to_uint32(const_shape_node->rank()));
306 select_axis = const_shape_node->scalar<loco::DataType::S32>();
308 assert(select_axis < input_shape.rank());
309 assert(select_axis >= 0); // TODO support minus of this breaks
311 // NOTE select_axis is removed
312 loco::TensorShape shape_output;
313 uint32_t rank = input_shape.rank();
314 uint32_t shrink = static_cast<uint32_t>(select_axis);
316 shape_output.rank(rank - 1);
317 for (uint32_t r = 0, d = 0; r < rank; ++r)
321 shape_output.dim(d++) = input_shape.dim(r);
323 return loco::NodeShape{shape_output};
326 // Call this for CircleAvgPool2D and CircleMaxPool2D only
327 template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
329 LUCI_ASSERT(loco::shape_known(node->value()), "Shape must be known");
331 auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
332 assert(ifm_shape.rank() == 4);
334 uint32_t input_height = ifm_shape.dim(1).value();
335 uint32_t input_width = ifm_shape.dim(2).value();
336 uint32_t stride_height = node->stride()->h();
337 uint32_t stride_width = node->stride()->w();
338 uint32_t window_height = node->filter()->h();
339 uint32_t window_width = node->filter()->w();
340 uint32_t dilation_height = 1; // dilation for CircleAvgPool2D and CircleMaxPool2D is 1
341 uint32_t dilation_width = 1;
342 uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
343 uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
345 uint32_t output_height = 0;
346 uint32_t output_width = 0;
348 if (node->padding() == luci::Padding::VALID)
350 output_height = (input_height + stride_height - effective_window_height) / stride_height;
351 output_width = (input_width + stride_width - effective_window_width) / stride_width;
353 else if (node->padding() == luci::Padding::SAME)
355 output_height = (input_height + stride_height - 1) / stride_height;
356 output_width = (input_width + stride_width - 1) / stride_width;
359 LUCI_ASSERT(false, "Wrong padding type");
361 loco::TensorShape ofm_shape;
363 ofm_shape.dim(0) = ifm_shape.dim(0);
364 ofm_shape.dim(1) = output_height;
365 ofm_shape.dim(2) = output_width;
366 ofm_shape.dim(3) = ifm_shape.dim(3);
368 return loco::NodeShape{ofm_shape};
371 loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node)
373 const loco::DataType S32 = loco::DataType::S32;
375 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
376 // Support only input rank is 3 and 4
377 assert(input_shape.rank() == 3 || input_shape.rank() == 4);
379 // Only support block_shape() with S32 type CircleConst for now
380 auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
381 LUCI_ASSERT(const_block_shape->dtype() == loco::DataType::S32, "Only support int32 block_shape");
383 // Only support crops() with S32 type CircleConst for now
384 auto const_crops = loco::must_cast<luci::CircleConst *>(node->crops());
385 LUCI_ASSERT(const_crops->dtype() == loco::DataType::S32, "Only support int32 crops");
387 auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
388 auto const_crops_shape = loco::shape_get(const_crops).as<loco::TensorShape>();
389 assert(const_block_shape_shape.rank() == 1);
390 assert(const_crops_shape.rank() == 2);
392 int32_t input_spatial_dim = input_shape.rank() - 2;
393 assert(const_block_shape_shape.dim(0) == input_spatial_dim);
394 assert(const_crops_shape.dim(0) == input_spatial_dim);
395 assert(const_crops_shape.dim(1) == 2);
397 loco::TensorShape shape_output;
399 shape_output.rank(input_shape.rank());
401 int32_t output_batch_size = input_shape.dim(0).value();
402 for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
404 int dim_size = input_shape.dim(dim + 1).value() * const_block_shape->at<S32>(dim);
405 dim_size -= const_crops->at<S32>(dim * 2);
406 dim_size -= const_crops->at<S32>(dim * 2 + 1);
407 shape_output.dim(dim + 1) = dim_size;
409 assert(output_batch_size % const_block_shape->at<S32>(dim) == 0);
410 output_batch_size = output_batch_size / const_block_shape->at<S32>(dim);
412 shape_output.dim(0) = output_batch_size;
413 shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);
415 return loco::NodeShape{shape_output};
424 template <class Conv2DType> OutputSize infer_conv2d_type(const Conv2DType *node)
426 auto ifm_shape = loco::shape_get(node->input()).template as<loco::TensorShape>();
427 auto ker_shape = loco::shape_get(node->filter()).template as<loco::TensorShape>();
428 assert(ifm_shape.rank() == 4);
429 assert(ker_shape.rank() == 4);
431 uint32_t input_height = ifm_shape.dim(1).value();
432 uint32_t input_width = ifm_shape.dim(2).value();
433 uint32_t stride_height = node->stride()->h();
434 uint32_t stride_width = node->stride()->w();
435 uint32_t ker_height = ker_shape.dim(1).value();
436 uint32_t ker_width = ker_shape.dim(2).value();
437 uint32_t dilation_height = node->dilation()->h();
438 uint32_t dilation_width = node->dilation()->w();
439 uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
440 uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
442 uint32_t output_height = 0;
443 uint32_t output_width = 0;
445 if (node->padding() == luci::Padding::VALID)
447 output_height = (input_height + stride_height - effective_ker_height) / stride_height;
448 output_width = (input_width + stride_width - effective_ker_width) / stride_width;
450 else if (node->padding() == luci::Padding::SAME)
452 output_height = (input_height + stride_height - 1) / stride_height;
453 output_width = (input_width + stride_width - 1) / stride_width;
456 LUCI_ASSERT(false, "Wrong padding type");
458 OutputSize os{output_height, output_width};
463 // BatchMatMulV2 supports broadcasting in the batch dimensions(BatchMatMul doesn't)
464 // TODO Distinguish BatchMatMul and BatchMatMulV2
465 loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape,
466 const loco::TensorShape &y_shape, bool adj_x, bool adj_y)
468 uint32_t x_rank = x_shape.rank();
469 uint32_t y_rank = y_shape.rank();
470 assert(x_rank >= 2 && y_rank >= 2);
472 loco::TensorShape output_shape;
473 output_shape.rank(x_shape.rank());
474 // Braodcast in the batch dimension
475 if (x_rank > 2 || y_rank > 2)
477 loco::TensorShape dummy_x = x_shape;
478 loco::TensorShape dummy_y = y_shape;
479 expand_rank(dummy_x, dummy_y);
481 expand_rank(output_shape, dummy_y);
483 for (uint32_t d = 0; d < output_shape.rank() - 2; d++)
485 uint32_t max_dim = std::max(dummy_x.dim(d).value(), dummy_y.dim(d).value());
486 if (dummy_x.dim(d) == dummy_y.dim(d) ||
487 dummy_x.dim(d).value() * dummy_y.dim(d).value() == max_dim)
488 output_shape.dim(d).set(max_dim);
490 INTERNAL_EXN("BatchMatMul has wrong shape");
494 loco::Dimension x_lhs = adj_x ? x_shape.dim(x_rank - 1) : x_shape.dim(x_rank - 2);
495 loco::Dimension x_rhs = adj_x ? x_shape.dim(x_rank - 2) : x_shape.dim(x_rank - 1);
496 loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
497 loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);
499 if (not(x_rhs == y_lhs))
500 INTERNAL_EXN("x_rhs and y_lhs should be same");
502 uint32_t out_rank = output_shape.rank();
503 output_shape.dim(out_rank - 2) = x_lhs;
504 output_shape.dim(out_rank - 1) = y_rhs;
506 return loco::NodeShape{output_shape};
509 loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node)
511 // TODO Support when CircleConcatenation has 0 input
512 assert(node->numValues() > 0);
514 auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
515 auto axis = node->axis();
517 axis += first_shape.rank();
520 assert(first_shape.rank() > static_cast<uint32_t>(axis));
522 loco::TensorShape output_shape;
524 output_shape.rank(first_shape.rank());
525 for (uint32_t i = 0; i < output_shape.rank(); ++i)
526 output_shape.dim(i) = first_shape.dim(i);
528 for (uint32_t i = 1; i < node->numValues(); ++i)
530 auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
532 for (uint32_t j = 0; j < output_shape.rank(); ++j)
534 if (j == static_cast<uint32_t>(axis))
535 output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
537 assert(output_shape.dim(j) == input_shape.dim(j));
541 return loco::NodeShape{output_shape};
544 loco::NodeShape infer_conv2d(const luci::CircleConv2D *node)
548 auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
549 auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
551 INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank()
554 assert(ifm_shape.rank() == 4);
555 assert(ker_shape.rank() == 4);
556 assert(ifm_shape.dim(3) == ker_shape.dim(3));
558 auto os = infer_conv2d_type(node);
560 loco::TensorShape ofm_shape;
562 ofm_shape.dim(0) = ifm_shape.dim(0);
563 ofm_shape.dim(1) = os.height;
564 ofm_shape.dim(2) = os.width;
565 ofm_shape.dim(3) = ker_shape.dim(0);
567 return loco::NodeShape{ofm_shape};
570 loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node)
572 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
573 LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
575 // Only data format NHWC is supported
576 // TODO need to clarify what to do with layout in this operator
577 int32_t height = input_shape.dim(1).value();
578 int32_t width = input_shape.dim(2).value();
579 int32_t depth = input_shape.dim(3).value();
581 int block_size = node->block_size();
584 INTERNAL_EXN("Block size must be >= 2");
586 if (depth % (block_size * block_size))
588 INTERNAL_EXN("The input tensor's depth must be divisible by block_size^2");
591 loco::TensorShape output_shape;
592 output_shape.rank(4);
594 output_shape.dim(0) = input_shape.dim(0).value();
595 output_shape.dim(1) = height * block_size;
596 output_shape.dim(2) = width * block_size;
597 output_shape.dim(3) = depth / (block_size * block_size);
599 return loco::NodeShape{output_shape};
602 loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node)
604 auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
605 auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
607 assert(ifm_shape.rank() == 4);
608 assert(ker_shape.rank() == 4);
609 assert(ker_shape.dim(0).value() == 1);
611 auto os = infer_conv2d_type(node);
613 loco::TensorShape ofm_shape;
615 ofm_shape.dim(0) = ifm_shape.dim(0);
616 ofm_shape.dim(1) = os.height;
617 ofm_shape.dim(2) = os.width;
618 ofm_shape.dim(3) = ker_shape.dim(3);
620 return loco::NodeShape{ofm_shape};
623 loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node)
625 const loco::DataType S32 = loco::DataType::S32;
626 auto x_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
627 if (x_shape.rank() == 0)
629 // This maybe for unknown shape. We use shape from the node itself.
630 return use_own(node);
632 auto const_axis = loco::must_cast<luci::CircleConst *>(node->axis());
633 LUCI_ASSERT(const_axis->dtype() == S32, "Only support int32 CircleConst for axis");
634 if (const_axis->rank() != 0 && const_axis->rank() != 1)
636 INTERNAL_EXN_V("Non-scalar axis in OP", node->opnum());
638 int32_t axis = const_axis->at<S32>(0);
639 LUCI_ASSERT((axis <= static_cast<int32_t>(x_shape.rank())) &&
640 (axis >= -1 - static_cast<int32_t>(x_shape.rank())),
641 "Axis has to be between [-(D+1), D], where D is rank of input.");
642 size_t positive_axis = axis < 0 ? x_shape.rank() + axis + 1 : axis;
643 loco::TensorShape output_shape;
644 output_shape.rank(x_shape.rank() + 1);
646 for (; i < positive_axis; i++)
647 output_shape.dim(i) = x_shape.dim(i);
648 output_shape.dim(i) = loco::Dimension(1);
649 for (; i < x_shape.rank(); i++)
650 output_shape.dim(i + 1) = x_shape.dim(i);
651 return loco::NodeShape{output_shape};
654 loco::NodeShape infer_fill(const luci::CircleFill *node)
656 loco::TensorShape shape;
658 LUCI_ASSERT(node->dims(), "dims input should not be nullptr");
660 auto dims_node = dynamic_cast<luci::CircleConst *>(node->dims());
661 if (dims_node != nullptr)
663 // Only support node with S32
664 LUCI_ASSERT(dims_node->dtype() == loco::DataType::S32, "Only support int32 CircleConst");
666 if (dims_node->rank() != 1)
667 INTERNAL_EXN_V("Only support rank 1 CircleConst", oops::to_uint32(dims_node->rank()));
669 shape.rank(dims_node->dim(0).value());
671 for (uint32_t axis = 0; axis < shape.rank(); ++axis)
673 shape.dim(axis) = dims_node->at<loco::DataType::S32>(axis);
678 shape = own_shape(node);
682 return loco::NodeShape{shape};
685 loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
687 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
688 auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>();
690 // Checking shape capability for fully connected layer
691 // Input: a tensor of at least rank 2 [D1, D2, ... Dn]
692 // Weight: [# of units, K]
693 // Output: [D1 * D2 * ... * Dn / K, # of units]
694 if (input_shape.rank() < 2 || weights_shape.rank() != 2)
696 // Return node own shape if shape inference is not possible
697 return use_own(node);
700 uint32_t input_size = 1;
701 for (uint32_t i = 0; i < input_shape.rank(); i++)
703 input_size = input_size * input_shape.dim(i).value();
705 const uint32_t batch_size = input_size / weights_shape.dim(1).value();
706 loco::TensorShape out_shape;
708 out_shape.dim(0) = batch_size;
709 out_shape.dim(1) = weights_shape.dim(0);
711 return loco::NodeShape{out_shape};
714 loco::NodeShape infer_gather(const luci::CircleGather *node)
716 loco::TensorShape output_shape;
718 const auto input_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
719 const auto positions_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
720 int32_t axis = node->axis();
722 // If CircleGather input has a dynamic shape, it can't inference this shape. So, it returns the
723 // shape that node already has.
724 if (input_shape.rank() == 0 || positions_shape.rank() == 0)
725 return use_own(node);
728 axis += input_shape.rank();
730 output_shape.rank(input_shape.rank() - 1 + positions_shape.rank());
731 int32_t outdim_index = 0;
732 for (int32_t i = 0; i < axis; ++i)
733 output_shape.dim(outdim_index++) = input_shape.dim(i);
734 for (uint32_t i = 0; i < positions_shape.rank(); ++i)
735 output_shape.dim(outdim_index++) = positions_shape.dim(i);
736 for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
737 output_shape.dim(outdim_index++) = input_shape.dim(i);
739 return loco::NodeShape{output_shape};
742 loco::NodeShape infer_gather_nd(const luci::CircleGatherNd *node)
744 loco::TensorShape output_shape;
746 const auto params_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
747 const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
749 const auto params_rank = params_shape.rank();
750 const auto indices_rank = indices_shape.rank();
752 // see https://www.tensorflow.org/api_docs/python/tf/gather_nd
753 // output.shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]
754 // batch_dims isn't supported in tflite
756 // TODO: replace exceptions with setting shape to unknown?
758 if (!indices_shape.dim(indices_rank - 1).known())
759 INTERNAL_EXN("Last indices dimension is unknown");
761 auto indices_last_dim = indices_shape.dim(indices_rank - 1).value();
763 if (indices_last_dim > params_rank)
764 INTERNAL_EXN("Last indices dimension should be <= params rank");
766 const uint32_t output_rank = indices_rank + params_rank - indices_last_dim - 1;
768 output_shape.rank(output_rank);
770 uint32_t output_index = 0;
771 for (uint32_t i = 0; i < indices_rank - 1; ++i)
773 auto &dim = indices_shape.dim(i);
775 INTERNAL_EXN("Unknown indices dimension is unsupported");
776 output_shape.dim(output_index++).set(dim.value());
779 for (uint32_t i = indices_last_dim; i < params_rank; ++i)
781 auto &dim = params_shape.dim(i);
783 INTERNAL_EXN("Unknown params dimension is unsupported");
784 output_shape.dim(output_index++).set(dim.value());
787 return loco::NodeShape{output_shape};
790 loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node)
792 loco::TensorShape output_shape;
794 auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
795 auto rank = diagonal_shape.rank();
797 output_shape.rank(rank + 1);
799 for (uint32_t i = 0; i < rank; i++)
801 output_shape.dim(i) = diagonal_shape.dim(i);
804 output_shape.dim(rank) = diagonal_shape.dim(rank - 1);
806 return loco::NodeShape{output_shape};
809 loco::NodeShape infer_matrix_set_diag(const luci::CircleMatrixSetDiag *node)
811 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
812 auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
814 auto rank = diagonal_shape.rank();
816 LUCI_ASSERT(rank == input_shape.rank() - 1, "diagonal rank = input rank - 1");
818 for (uint32_t i = 0; i < rank - 1; i++)
820 LUCI_ASSERT(diagonal_shape.dim(i) == input_shape.dim(i), "diagonal dims = input dims");
823 auto dim = std::min(input_shape.dim(rank - 1).value(), input_shape.dim(rank).value());
825 LUCI_ASSERT(dim == diagonal_shape.dim(rank - 1), "Max diag len error");
827 return loco::NodeShape{input_shape};
830 loco::TensorShape infer_reducer(const loco::Node *input, const loco::Node *indices, bool keep_dims)
832 const loco::DataType S32 = loco::DataType::S32;
834 auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
835 auto reduction_indices = loco::must_cast<const luci::CircleConst *>(indices);
838 // TODO support non-const case
839 // TODO support other data type
840 LUCI_ASSERT(reduction_indices->dtype() == S32, "Only support int 32");
843 std::vector<int32_t> reduction_values;
845 for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i)
847 int32_t axis = reduction_indices->at<S32>(i);
849 axis += input_shape.rank();
850 if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank())))
851 INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
852 reduction_values.push_back(axis);
855 loco::TensorShape output_shape;
859 output_shape.rank(input_shape.rank());
860 for (uint32_t i = 0; i < input_shape.rank(); ++i)
861 output_shape.dim(i) = input_shape.dim(i);
862 for (uint32_t i = 0; i < reduction_values.size(); ++i)
863 output_shape.dim(reduction_values.at(i)) = 1;
867 std::vector<bool> check_reduce(input_shape.rank(), false);
868 for (uint32_t i = 0; i < reduction_values.size(); ++i)
869 check_reduce.at(reduction_values.at(i)) = true;
871 uint32_t reduce_cnt = 0;
872 for (uint32_t i = 0; i < check_reduce.size(); ++i)
873 if (check_reduce.at(i))
876 output_shape.rank(input_shape.rank() - reduce_cnt);
877 for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
878 if (check_reduce.at(i) == false)
879 output_shape.dim(j++) = input_shape.dim(i);
885 loco::NodeShape infer_mirror_pad(const luci::CircleMirrorPad *node)
887 // TODO support non-const case
888 auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
889 return use_paddings(node, paddings);
892 loco::NodeShape infer_one_hot(const luci::CircleOneHot *node)
894 const loco::DataType S32 = loco::DataType::S32;
895 auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
896 // Only support OneHot node's depth() is CircleConst with type S32
897 // TODO support depth with other types
898 auto depth = loco::must_cast<luci::CircleConst *>(node->depth());
899 LUCI_ASSERT(depth->dtype() == S32, "Only support int32 CircleConst");
900 if (depth->rank() != 0)
901 INTERNAL_EXN_V("Only support rank 0 CircleOneHot in Depth", oops::to_uint32(depth->rank()));
902 loco::TensorShape output_shape;
903 output_shape.rank(indices_shape.rank() + 1);
904 auto axis = node->axis();
906 axis += indices_shape.rank() + 1;
907 LUCI_ASSERT(0 <= axis, "Axis is out of range");
908 LUCI_ASSERT(static_cast<uint32_t>(axis) <= indices_shape.rank(), "Axis is out of range");
910 for (uint32_t i = 0; i < output_shape.rank(); i++)
912 if (i == static_cast<uint32_t>(axis))
914 output_shape.dim(i) = depth->at<S32>(0);
918 output_shape.dim(i) = indices_shape.dim(j++);
921 return loco::NodeShape{output_shape};
924 loco::NodeShape infer_pack(const luci::CirclePack *node)
926 LUCI_ASSERT(node->values_count() > 0, "Only support one or more inputs");
928 auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
929 // Make sure all inputs have the same shape.
930 for (uint32_t i = 1; i < node->values_count(); ++i)
932 auto in_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
933 LUCI_ASSERT(loco::NodeShape{first_shape} == loco::NodeShape{in_shape},
934 "All inputs must have the same shape");
937 // Checking shape capability for pack layer
938 // Input: tensors [D1, D2, ... Dn]
940 // Output: [D1, D2, ... , D_K-1, n, D_K+1, ... Dn]
941 auto axis = node->axis();
943 axis += first_shape.rank() + 1;
945 LUCI_ASSERT(0 <= axis, "Axis is out of range");
946 LUCI_ASSERT(static_cast<uint32_t>(axis) <= first_shape.rank(), "Axis is out of range");
948 loco::TensorShape output_shape;
949 output_shape.rank(first_shape.rank() + 1);
952 for (uint32_t i = 0; i < output_shape.rank(); ++i)
954 if (i == static_cast<uint32_t>(axis))
956 output_shape.dim(i) = node->values_count();
960 output_shape.dim(i) = first_shape.dim(j++);
964 return loco::NodeShape{output_shape};
967 loco::NodeShape infer_pad(const luci::CirclePad *node)
969 // TODO support non-const case
970 auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
971 return use_paddings(node, paddings);
974 loco::NodeShape infer_pad_v2(const luci::CirclePadV2 *node)
976 // TODO support non-const case
977 auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
980 auto node_shape = own_shape(node);
981 return loco::NodeShape{node_shape};
983 return use_paddings(node, paddings);
986 loco::NodeShape infer_p_relu(const luci::CirclePRelu *node)
988 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
989 auto alpha_shape = loco::shape_get(node->alpha()).as<loco::TensorShape>();
991 auto output_shape = broadcast_shape(input_shape, alpha_shape);
993 return loco::NodeShape{output_shape};
996 loco::NodeShape infer_range(const luci::CircleRange *node)
998 loco::TensorShape output_shape;
999 output_shape.rank(1);
1001 auto start_node = dynamic_cast<luci::CircleConst *>(node->start());
1002 auto limit_node = dynamic_cast<luci::CircleConst *>(node->limit());
1003 auto delta_node = dynamic_cast<luci::CircleConst *>(node->delta());
1005 if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
1007 return use_own(node);
1010 double start = 0, limit = 0, delta = 0;
1012 #define GET_RANGE_PARAM(DT) \
1013 start = start_node->scalar<DT>(); \
1014 limit = limit_node->scalar<DT>(); \
1015 delta = delta_node->scalar<DT>();
1017 switch (start_node->dtype())
1019 case loco::DataType::FLOAT32:
1020 GET_RANGE_PARAM(loco::DataType::FLOAT32)
1022 case loco::DataType::S32:
1023 GET_RANGE_PARAM(loco::DataType::S32)
1026 INTERNAL_EXN("Range data type not supported");
1029 #undef GET_RANGE_PARAM
1032 INTERNAL_EXN("Delta can not be zero");
1034 output_shape.dim(0) = ceil((limit - start) / delta);
1036 return loco::NodeShape{output_shape};
// Reshape: prefer the constant shape() input; fall back to the node's own
// shape when shape() is not const. A single -1 entry is resolved from the
// input element count. The newShape() attribute is only cross-checked (logged
// on mismatch), never used as the result.
loco::NodeShape infer_reshape(const luci::CircleReshape *node)
{
  LOGGER(l);

  const loco::DataType S32 = loco::DataType::S32;

  loco::TensorShape shape_by_input;
  {
    LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");

    // Only support node's shape() is CircleConst with S32
    // TODO support other node with other types
    auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
    if (const_shape_node != nullptr)
    {
      LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");

      shape_by_input.rank(const_shape_node->size<S32>());

      for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
      {
        shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
      }
    }
    else
    {
      // We use shape from the node itself
      shape_by_input = own_shape(node);
    }
  }

  loco::TensorShape shape_by_attr;
  {
    shape_by_attr.rank(node->newShape()->rank());

    for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
    {
      shape_by_attr.dim(axis) = node->newShape()->dim(axis);
    }
  }

  // mismatch is informational only; shape_by_input wins below
  if (!(shape_by_input == shape_by_attr))
  {
    INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
    INFO(l) << "   shape_by_input : " << shape_by_input << std::endl;
    INFO(l) << "   shape_by_attr : " << shape_by_attr << std::endl;
  }

  loco::TensorShape output_shape = shape_by_input;

  // One of the dimensions can have special value -1, meaning its actual value should be inferred.
  const auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
  const uint32_t input_element_count = loco::element_count(&input_shape);
  uint32_t output_element_count = 1;
  uint32_t unknown_dim_index = UINT32_MAX;
  for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
  {
    // -1 stored in an unsigned dim comes back as UINT32_MAX; detect via cast
    const uint32_t dim_value = output_shape.dim(dim_index).value();
    if (static_cast<int>(dim_value) == -1)
    {
      LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
      unknown_dim_index = dim_index;
    }
    else
    {
      output_element_count *= dim_value;
    }
  }
  if (unknown_dim_index != UINT32_MAX)
  {
    // the -1 dimension absorbs whatever is left of the element count
    output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
  }

  return loco::NodeShape{output_shape};
}
1115 loco::NodeShape infer_resize_bilinear(const luci::CircleResizeBilinear *node)
1117 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1119 if (input_shape.rank() != 4)
1120 INTERNAL_EXN("Expected ResizeBilinear input to have rank 4");
1122 auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
1124 if (const_node->dtype() != loco::DataType::S32)
1125 INTERNAL_EXN("Only S32 datatype is supported for ResizeBilinear size");
1127 if (const_node->rank() != 1)
1128 INTERNAL_EXN("Expected size tensor of rank 1");
1130 if (const_node->dim(0).value() != 2)
1131 INTERNAL_EXN("Expected size tensor with shape [2]");
1133 loco::TensorShape output_shape;
1134 output_shape.rank(4);
1135 output_shape.dim(0) = input_shape.dim(0);
1136 output_shape.dim(1) = const_node->at<loco::DataType::S32>(0);
1137 output_shape.dim(2) = const_node->at<loco::DataType::S32>(1);
1138 output_shape.dim(3) = input_shape.dim(3);
1140 return loco::NodeShape{output_shape};
1143 loco::NodeShape infer_resize_nearest_neighbor(const luci::CircleResizeNearestNeighbor *node)
1145 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1147 if (input_shape.rank() != 4)
1148 INTERNAL_EXN("Expected ResizeNearesNeighbor input to have rank 4");
1150 auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
1152 if (const_node->dtype() != loco::DataType::S32)
1153 INTERNAL_EXN("Only S32 datatype is supported for ResizeNearesNeighbor size");
1155 if (const_node->rank() != 1)
1156 INTERNAL_EXN("Expected size tensor of rank 1");
1158 if (const_node->dim(0).value() != 2)
1159 INTERNAL_EXN("Expected size tensor with shape [2]");
1161 loco::TensorShape output_shape;
1162 output_shape.rank(4);
1163 output_shape.dim(0) = input_shape.dim(0);
1164 output_shape.dim(1) = const_node->at<loco::DataType::S32>(0);
1165 output_shape.dim(2) = const_node->at<loco::DataType::S32>(1);
1166 output_shape.dim(3) = input_shape.dim(3);
1168 return loco::NodeShape{output_shape};
// ScatterNd: output shape is given verbatim by the constant shape() input
// (int32 or int64 entries).
loco::NodeShape infer_scatter_nd(const luci::CircleScatterNd *node)
{
  loco::TensorShape output_shape;

  // shape() must be a CircleConst; non-const shape is not supported here
  auto shape_node = loco::must_cast<luci::CircleConst *>(node->shape());

  const loco::DataType S32 = loco::DataType::S32;
  const loco::DataType S64 = loco::DataType::S64;

  std::vector<int64_t> vect_shape;

  if (shape_node->dtype() == S32)
    vect_shape = vector_from_constant<S32>(shape_node);
  else if (shape_node->dtype() == S64)
    vect_shape = vector_from_constant<S64>(shape_node);
  else
    LUCI_ASSERT(false, "Only support int32/int64 for shape()");

  // copy the constant values into the output shape, one dim per entry
  output_shape.rank(vect_shape.size());
  for (uint32_t i = 0; i < vect_shape.size(); ++i)
    output_shape.dim(i) = vect_shape[i];

  return loco::NodeShape{output_shape};
}
1196 loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node)
1198 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1199 auto segment_shape = loco::shape_get(node->segment_ids()).as<loco::TensorShape>();
1201 LUCI_ASSERT(segment_shape.rank() == 1, "segment_ids must be 1-D tensor");
1202 LUCI_ASSERT(segment_shape.dim(0).value() == input_shape.dim(0).value(),
1203 "segment_ids size must be equal to the size of data's first dimension");
1205 auto ids_shape_value = loco::must_cast<luci::CircleConst *>(node->segment_ids());
1207 std::vector<int64_t> vect_ids;
1209 if (ids_shape_value->dtype() == loco::DataType::S32)
1210 vect_ids = vector_from_constant<loco::DataType::S32>(ids_shape_value);
1212 LUCI_ASSERT(std::is_sorted(vect_ids.begin(), vect_ids.end()),
1213 "segment_ids values should be sorted")
1215 loco::TensorShape output_shape;
1217 output_shape.rank(input_shape.rank());
1219 for (uint32_t i = 1; i < input_shape.rank(); ++i)
1220 output_shape.dim(i) = input_shape.dim(i);
1222 output_shape.dim(0) = vect_ids.back() + 1;
1224 return loco::NodeShape{output_shape};
// Select: output takes the (identical) shape of t()/e(); condition may be the
// same rank, a scalar, or a 1-D tensor whose length matches t's first dim.
loco::NodeShape infer_select(const luci::CircleSelect *node)
{
  auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
  assert(t_shape == loco::shape_get(node->e()).as<loco::TensorShape>());

  // condition shape validation
  auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
  if (c_shape.rank() != t_shape.rank())
  {
    // lower-rank condition is only allowed as scalar (rank 0) or vector (rank 1)
    if (c_shape.rank() != 0 && c_shape.rank() != 1)
      INTERNAL_EXN_V("CircleSelect condition rank is not 0 nor 1: ", c_shape.rank());

    if (c_shape.rank() == 1)
    {
      // a 1-D condition selects along the first axis, so lengths must agree
      if (c_shape.dim(0).value() != t_shape.dim(0).value())
        INTERNAL_EXN("CircleSelect condition dim(0) should match with t.dim(0)");
    }
  }

  return loco::NodeShape{t_shape};
}
1249 loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node)
1251 auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
1252 auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
1253 auto e_shape = loco::shape_get(node->e()).as<loco::TensorShape>();
1255 // validate ability to broadcast shapes to each other
1256 auto b_shape = broadcast_shape(broadcast_shape(c_shape, t_shape), e_shape);
1257 return loco::NodeShape{b_shape};
1260 loco::NodeShape infer_shape(const luci::CircleShape *node)
1262 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1264 loco::TensorShape output_shape;
1266 output_shape.rank(1);
1267 output_shape.dim(0) = input_shape.rank();
1269 return loco::NodeShape{output_shape};
// Slice: output dim i is size[i]; a size of -1 means "everything from
// begin[i] to the end of that axis". begin/size must be constant.
loco::NodeShape infer_slice(const luci::CircleSlice *node)
{
  const loco::DataType S32 = loco::DataType::S32;
  const loco::DataType S64 = loco::DataType::S64;

  auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();

  auto const_begin = loco::must_cast<luci::CircleConst *>(node->begin());
  auto const_size = loco::must_cast<luci::CircleConst *>(node->size());

  loco::TensorShape output_shape;
  std::vector<int64_t> vect_begin; // to hold both S32/S64, we use int64_t
  std::vector<int64_t> vect_size;

  if (const_begin->dtype() == S32)
    vect_begin = vector_from_constant<S32>(const_begin);
  else if (const_begin->dtype() == S64)
    vect_begin = vector_from_constant<S64>(const_begin);
  else
    LUCI_ASSERT(false, "Only support int32/int64 for begin()");

  if (const_size->dtype() == S32)
    vect_size = vector_from_constant<S32>(const_size);
  else if (const_size->dtype() == S64)
    vect_size = vector_from_constant<S64>(const_size);
  else
    LUCI_ASSERT(false, "Only support int32/int64 for size()");

  // begin and size must provide one entry per input dimension
  assert(input_shape.rank() == vect_begin.size());
  assert(input_shape.rank() == vect_size.size());

  output_shape.rank(vect_begin.size());
  for (uint32_t idx = 0; idx < vect_begin.size(); ++idx)
  {
    auto size = vect_size.at(idx);
    // -1: take the remainder of this axis starting at begin
    if (size == -1)
    {
      size = input_shape.dim(idx).value() - vect_begin.at(idx);
    }
    output_shape.dim(idx) = size;
  }

  return loco::NodeShape{output_shape};
}
// SpaceToBatchND: each spatial dim is padded then divided by its block size;
// the batch dim is multiplied by the product of block_shape. The last
// (channel) dim passes through unchanged.
loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
{
  const loco::DataType S32 = loco::DataType::S32;

  auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
  // Support only input rank is 3 and 4
  assert(input_shape.rank() == 3 || input_shape.rank() == 4);

  // Only support block_shape() with S32 type CircleConst for now
  auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
  LUCI_ASSERT(const_block_shape->dtype() == S32, "Only support int32 block_shape");

  // Only support paddings() with S32 type CircleConst for now
  auto const_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
  LUCI_ASSERT(const_paddings->dtype() == S32, "Only support int32 paddings");

  auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
  auto const_paddings_shape = loco::shape_get(const_paddings).as<loco::TensorShape>();
  assert(const_block_shape_shape.rank() == 1);
  assert(const_paddings_shape.rank() == 2);

  // spatial dims = all dims except batch (first) and channel (last)
  int32_t input_spatial_dim = input_shape.rank() - 2;
  assert(const_block_shape_shape.dim(0) == input_spatial_dim);
  assert(const_paddings_shape.dim(0) == input_spatial_dim);
  assert(const_paddings_shape.dim(1) == 2);

  // Check all values of block_shape >= 1
  uint32_t ele_count = const_block_shape->size<S32>();
  for (uint32_t e = 0; e < ele_count; ++e)
  {
    auto val = const_block_shape->at<S32>(e);
    if (val < 1)
    {
      INTERNAL_EXN_V("All values of block_shape >= 1: ", e);
    }
  }

  loco::TensorShape shape_output;

  shape_output.rank(input_shape.rank());

  int32_t output_batch_size = input_shape.dim(0).value();
  for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
  {
    // padded spatial extent must divide evenly by the block size
    int dim_size = input_shape.dim(dim + 1).value();
    dim_size += const_paddings->at<S32>(dim * 2);
    dim_size += const_paddings->at<S32>(dim * 2 + 1);
    shape_output.dim(dim + 1) = dim_size / const_block_shape->at<S32>(dim);

    assert(dim_size % const_block_shape->at<S32>(dim) == 0);
    output_batch_size = output_batch_size * const_block_shape->at<S32>(dim);
  }
  shape_output.dim(0) = output_batch_size;
  shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);

  return loco::NodeShape{shape_output};
}
// SpaceToDepth (NHWC): [N, H, W, C] -> [N, H/b, W/b, C*b*b] for block size b.
loco::NodeShape infer_space_to_depth(const luci::CircleSpaceToDepth *node)
{
  auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
  LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");

  // Only data format NHWC is supported
  int32_t height = input_shape.dim(1).value();
  int32_t width = input_shape.dim(2).value();
  int32_t depth = input_shape.dim(3).value();

  int block_size = node->block_size();

  if (block_size < 2)
    INTERNAL_EXN("Block size must be >= 2");

  if ((height % block_size) || (width % block_size))
  {
    INTERNAL_EXN("The input tensor's height and width must be divisible by block_size");
  }

  loco::TensorShape output_shape;
  output_shape.rank(4);

  output_shape.dim(0) = input_shape.dim(0).value();
  output_shape.dim(1) = height / block_size;
  output_shape.dim(2) = width / block_size;
  // each b x b spatial block is folded into the channel dimension
  output_shape.dim(3) = block_size * block_size * depth;

  return loco::NodeShape{output_shape};
}
// SparseToDense: output shape is the constant output_shape() input (rank-1,
// int32); falls back to the node's own shape when output_shape() is non-const.
loco::NodeShape infer_sparse_to_dense(const luci::CircleSparseToDense *node)
{
  loco::TensorShape shape;

  LUCI_ASSERT(node->output_shape(), "dims input should not be nullptr");

  auto output_shape_node = dynamic_cast<luci::CircleConst *>(node->output_shape());
  if (output_shape_node != nullptr)
  {
    // Only support node with S32
    LUCI_ASSERT(output_shape_node->dtype() == loco::DataType::S32,
                "Only support int32 CircleConst");

    if (output_shape_node->rank() != 1)
      INTERNAL_EXN_V("Only support rank 1 CircleConst",
                     oops::to_uint32(output_shape_node->rank()));

    shape.rank(output_shape_node->size<loco::DataType::S32>());

    for (uint32_t axis = 0; axis < shape.rank(); ++axis)
    {
      shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis);
    }
  }
  else
  {
    // non-const dims: keep whatever shape the node was imported with
    shape = own_shape(node);
  }

  return loco::NodeShape{shape};
}
1439 loco::NodeShape infer_strided_slice(const luci::CircleStridedSlice *node)
1441 auto begin_node = dynamic_cast<luci::CircleConst *>(node->begin());
1442 auto end_node = dynamic_cast<luci::CircleConst *>(node->end());
1443 auto strides_node = dynamic_cast<luci::CircleConst *>(node->strides());
1445 if (begin_node == nullptr || end_node == nullptr || strides_node == nullptr)
1447 return use_own(node);
1450 loco::TensorShape shape = infer_output_shape(node);
1451 return loco::NodeShape{shape};
// Squeeze: drop the dimensions listed in squeeze_dims (which must all have
// size 1); with an empty list, drop every size-1 dimension.
loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
{
  auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();

  // TODO input shape may be unknown before runtime
  std::vector<bool> do_squeeze(input_shape.rank(), false);
  uint32_t num_squeezed = 0;

  if (!node->squeeze_dims().empty())
  {
    // SqueezeDims not empty, squeeze only dims specified
    for (int32_t raw_dim : node->squeeze_dims())
    {
      // negative axes count from the back
      int32_t dim = raw_dim < 0 ? raw_dim + input_shape.rank() : raw_dim;

      if (dim < 0 || static_cast<uint32_t>(dim) >= input_shape.rank() ||
          input_shape.dim(dim).value() != 1)
      {
        INTERNAL_EXN("invalid dimention specified to Squeeze");
      }

      // guard against double-counting a dim listed twice
      if (!do_squeeze[dim])
        ++num_squeezed;
      do_squeeze[dim] = true;
    }
  }
  else
  {
    // SqueezeDims empty, squeeze any dims with size == 1
    for (uint32_t dim = 0; dim < input_shape.rank(); ++dim)
    {
      if (input_shape.dim(dim) == 1)
      {
        do_squeeze[dim] = true;
        ++num_squeezed;
      }
    }
  }

  loco::TensorShape output_shape;
  output_shape.rank(input_shape.rank() - num_squeezed);

  // copy the surviving dimensions in order
  for (uint32_t in_dim = 0, out_dim = 0; in_dim < input_shape.rank(); ++in_dim)
  {
    if (!do_squeeze[in_dim])
    {
      output_shape.dim(out_dim++) = input_shape.dim(in_dim);
    }
  }

  return loco::NodeShape{output_shape};
}
1507 loco::NodeShape infer_tile(const luci::CircleTile *node)
1509 const loco::DataType S32 = loco::DataType::S32;
1511 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1512 auto multiples = loco::must_cast<luci::CircleConst *>(node->multiples());
1514 // TODO support non-const case
1515 // TODO support S64 type
1516 LUCI_ASSERT(multiples->dtype() == S32, "Only support int32 multiples");
1517 LUCI_ASSERT(multiples->rank() == 1, "multiples should be rank 1")
1519 uint32_t n = multiples->dim(0).value();
1521 LUCI_ASSERT(n == input_shape.rank(), "length of multiples should be the same with input rank");
1523 loco::TensorShape output_shape;
1525 output_shape.rank(input_shape.rank());
1526 for (uint32_t ni = 0; ni < n; ++ni)
1528 int32_t multiple = multiples->at<S32>(ni);
1529 output_shape.dim(ni) = input_shape.dim(ni).value() * static_cast<uint32_t>(multiple);
1532 return loco::NodeShape{output_shape};
1535 loco::NodeShape infer_transpose(const luci::CircleTranspose *node)
1537 auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>();
1539 auto perm_node = loco::must_cast<luci::CircleConst *>(node->perm());
1541 loco::TensorShape output_shape;
1542 output_shape.rank(input_shape.rank());
1544 assert(perm_node->dtype() == loco::DataType::S32);
1545 assert(input_shape.rank() == perm_node->template size<loco::DataType::S32>());
1547 for (uint32_t out_axis = 0; out_axis < output_shape.rank(); out_axis++)
1549 auto in_axis = perm_node->template at<loco::DataType::S32>(out_axis);
1550 output_shape.dim(out_axis) = input_shape.dim(in_axis);
1553 return output_shape;
// TransposeConv: output shape is taken directly from the constant 'inputSizes'
// operand (rank-1 int32 tensor of 4 entries, assumed NHWC).
loco::NodeShape infer_transpose_conv(const luci::CircleTransposeConv *node)
{
  // TransposeConv's output shape is written in its 'inputSizes' argument
  auto input_sizes_const = loco::must_cast<luci::CircleConst *>(node->inputSizes());
  // TODO support non-const type
  LUCI_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
  LUCI_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
              "Only support rank 1 with 4 entries")

  loco::TensorShape shape;

  shape.rank(4);
  for (uint32_t axis = 0; axis < 4; ++axis)
    shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis);

  return loco::NodeShape{shape};
}
// Unpack: each output drops the 'axis' dimension of the input; 'num' must
// match the size of that dimension.
loco::NodeShape infer_unpack(const luci::CircleUnpack *node)
{
  // CircleUnpack provides list(array) of Tensors which has one less dimension of the input
  // We'll set shape of CircleUnpack to shape of actual outputs
  // TODO fix this if any problem rises
  auto value_shape = loco::shape_get(node->value()).as<loco::TensorShape>();

  auto axis = node->axis();
  auto num = node->num();
  auto rank = static_cast<int32_t>(value_shape.rank());

  // rank-0 input: nothing to unpack statically, keep the node's own shape
  if (rank == 0)
  {
    return use_own(node);
  }

  LUCI_ASSERT(-rank <= axis && axis < rank, "Axis is out of range");

  // normalize negative axis
  if (axis < 0)
    axis += rank;

  LUCI_ASSERT(num == static_cast<int32_t>(value_shape.dim(axis).value()),
              "num, axis maybe incorrect");

  loco::TensorShape output_shape;
  output_shape.rank(rank - 1);

  // copy all dimensions except the unpacked axis
  for (int32_t i = 0, o = 0; i < rank; ++i)
  {
    if (i != axis)
      output_shape.dim(o++) = value_shape.dim(i);
  }

  return loco::NodeShape{output_shape};
}
1611 loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectionalSequenceLSTM *node)
1613 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1614 auto recurrent_to_output_weights =
1615 loco::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>();
1616 auto rank = input_shape.rank();
1617 loco::TensorShape output_shape;
1618 output_shape.rank(rank);
1619 for (uint32_t i = 0; i < rank - 1; i++)
1621 output_shape.dim(i) = input_shape.dim(i);
1623 output_shape.dim(rank - 1) = recurrent_to_output_weights.dim(1);
1624 return loco::NodeShape{output_shape};
1627 loco::NodeShape infer_unique(const luci::CircleUnique *node)
1629 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1631 assert(input_shape.rank() == 1);
1633 loco::TensorShape shape_output;
1634 shape_output = own_shape(node);
1636 return loco::NodeShape{shape_output};
// BCQFullyConnected: output is [sum of qbits, input dim 1].
// NOTE(review): the i*2+1 stride suggests weights_clusters is laid out as
// flat [cluster][2] = {value, qbits} pairs — confirm against the BCQ format.
loco::NodeShape infer_bcq_fully_connected(const luci::CircleBCQFullyConnected *node)
{
  loco::TensorShape out_shape;

  auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
  auto weights_clusters = loco::must_cast<luci::CircleConst *>(node->weights_clusters());

  LUCI_ASSERT(input_shape.rank() == 2, "Input rank of BCQFullyConnected should be 2");

  // accumulate the qbits column over all clusters
  int32_t qbits_sum = 0;
  for (uint32_t i = 0; i < weights_clusters->dim(0).value(); ++i)
  {
    qbits_sum += weights_clusters->at<loco::DataType::S32>(i * 2 + 1);
  }

  out_shape.rank(2);
  out_shape.dim(0) = qbits_sum;
  out_shape.dim(1) = input_shape.dim(1);

  return loco::NodeShape{out_shape};
}
// BCQGather: reconstruct the logical (dequantized) input shape from the
// binary input and cluster info, then apply the standard Gather rule:
// output = input[0..axis) ++ indices ++ input(axis..rank)
loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node)
{
  loco::TensorShape input_shape;
  loco::TensorShape output_shape;

  const auto input_binary_shape = loco::shape_get(node->input_binary()).as<loco::TensorShape>();
  const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
  auto axis = node->axis();

  auto input_clusters = loco::must_cast<luci::CircleConst *>(node->input_clusters());
  // sum the qbits column ([i][1]) over all clusters
  int32_t qbits_sum = 0;
  for (uint32_t i = 0; i < input_clusters->dim(0).value(); ++i)
  {
    qbits_sum += input_clusters->at<loco::DataType::S32>(i * 2 + 1);
  }

  // each int32 of the binary input packs 32 weight bits
  input_shape.rank(2);
  input_shape.dim(0) = qbits_sum;
  input_shape.dim(1) = input_binary_shape.dim(1).value() * 32;

  output_shape.rank(input_shape.rank() - 1 + indices_shape.rank());
  int32_t outdim_index = 0;
  for (int32_t i = 0; i < axis; ++i)
    output_shape.dim(outdim_index++) = input_shape.dim(i);
  for (uint32_t i = 0; i < indices_shape.rank(); ++i)
    output_shape.dim(outdim_index++) = indices_shape.dim(i);
  for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
    output_shape.dim(outdim_index++) = input_shape.dim(i);

  return loco::NodeShape{output_shape};
}
1695 loco::NodeShape infer_input(const luci::CircleInput *node)
1697 loco::TensorShape shape;
1699 shape.rank(node->rank());
1700 for (uint32_t axis = 0; axis < node->rank(); axis++)
1701 shape.dim(axis) = node->dim(axis);
1703 return loco::NodeShape{shape};
1706 loco::NodeShape infer_output(const luci::CircleOutput *node)
1708 auto graph_outputs = node->graph()->outputs();
1709 auto graph_output = graph_outputs->at(node->index());
1710 auto output_shape = graph_output->shape();
1712 return loco::NodeShape{*output_shape};
loco::NodeShape infer_if_out(const luci::CircleIfOut *node)
{
  /**
   * @note  IF operator type and shape are that of the "then" and "else"
   *        Graph Outputs.
   */
  auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
  if (circle_if == nullptr)
  {
    INTERNAL_EXN("CircleIf IR is not configured correctly");
  }

  auto index = node->index();
  auto then_graph = circle_if->then_graph();
  auto else_graph = circle_if->else_graph();
  assert(then_graph != nullptr);
  assert(else_graph != nullptr);

  // shape and type are assumed to be same
  // these are checked at post_import_graph() in Import
  auto then_outputs = loco::output_nodes(then_graph);
  auto else_outputs = loco::output_nodes(else_graph);
  assert(then_outputs.size() == else_outputs.size());
  assert(index < static_cast<int32_t>(then_outputs.size()));

  auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
  auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));

  auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
  auto else_graph_outputs = else_graph->outputs();
  assert(then_graph_outputs->size() == else_graph_outputs->size());

  auto then_graph_output = then_graph_outputs->at(then_out->index());
  auto else_graph_output = else_graph_outputs->at(else_out->index());
  (void)else_graph_output; // make compiler happy for unused variable warnings
  assert(*then_graph_output->shape() == *else_graph_output->shape());

  // both branches agree (asserted above), so either branch's shape works
  return loco::NodeShape{*then_graph_output->shape()};
}
// NMSv4 outputs: index 0 = selected_indices ([max_output_size] when const,
// unknown otherwise); index 1 = valid_outputs (scalar, modeled as shape [0]).
loco::NodeShape infer_non_max_suppression_v4_out(const luci::CircleNonMaxSuppressionV4Out *node)
{
  const loco::DataType S32 = loco::DataType::S32;

  auto nmsv4 = dynamic_cast<const luci::CircleNonMaxSuppressionV4 *>(node->input());
  if (nmsv4 == nullptr)
    INTERNAL_EXN("CircleNonMaxSuppressionV4 IR is not configured correctly");

  auto index = node->index();
  if (index == 1)
    return loco::TensorShape({0});

  assert(index == 0);

  auto unknown = loco::TensorShape{loco::Dimension()};
  auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv4->max_output_size());
  if (max_output_size == nullptr)
    return unknown; // we need CircleConst for max output size

  LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");

  if (max_output_size->size<S32>() < 1)
    return unknown;

  auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
  return loco::TensorShape{max_output_size_value};
}
// NMSv5 outputs: index 0/1 = selected_indices/selected_scores
// ([max_output_size] when const, unknown otherwise); index 2 = valid_outputs
// (scalar, modeled as shape [0]).
loco::NodeShape infer_non_max_suppression_v5_out(const luci::CircleNonMaxSuppressionV5Out *node)
{
  const loco::DataType S32 = loco::DataType::S32;

  auto nmsv5 = dynamic_cast<const luci::CircleNonMaxSuppressionV5 *>(node->input());
  if (nmsv5 == nullptr)
    INTERNAL_EXN("CircleNonMaxSuppressionV5 IR is not configured correctly");

  auto index = node->index();
  if (index == 2)
    return loco::TensorShape({0});

  assert(index == 0 || index == 1);

  auto unknown = loco::TensorShape{loco::Dimension()};
  auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv5->max_output_size());
  if (max_output_size == nullptr)
    return unknown; // we need CircleConst for max output size

  LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");

  if (max_output_size->size<S32>() < 1)
    return unknown;

  auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
  return loco::TensorShape{max_output_size_value};
}
// SplitOut: each output equals the Split input shape with the split axis
// divided evenly by num_split. Needs a constant split_dim.
loco::NodeShape infer_split_out(const luci::CircleSplitOut *node)
{
  const loco::DataType S32 = loco::DataType::S32;

  auto split = dynamic_cast<const luci::CircleSplit *>(node->input());
  if (split == nullptr)
    INTERNAL_EXN("CircleSplit IR is not configured correctly");

  loco::NodeShape unknown;

  auto split_shape = loco::shape_get(split).as<loco::TensorShape>();

  auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
  if (split_dim == nullptr)
    return unknown; // we need CircleConst for split_dim
  LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");

  assert(split_dim->size<S32>() == 1);
  auto split_dim_axis = split_dim->at<S32>(0);
  // normalize negative axis
  if (split_dim_axis < 0)
    split_dim_axis += split_shape.rank();

  // the split axis must divide evenly across num_split outputs
  auto split_dim_value = split_shape.dim(split_dim_axis).value();
  assert(split_dim_value % split->num_split() == 0);
  const int split_depth = split_dim_value / split->num_split();

  loco::TensorShape output_shape = split_shape;

  // All shapes are equally same
  output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);

  return loco::NodeShape{output_shape};
}
// SplitVOut: size_splits gives each output's extent along split_dim; a single
// -1 entry absorbs the remainder. Needs constant size_splits and split_dim.
// NOTE(review): the "CircleSplit IR ..." message below says CircleSplit while
// this handles CircleSplitV — likely a copy-paste; confirm before changing.
loco::NodeShape infer_split_v_out(const luci::CircleSplitVOut *node)
{
  const loco::DataType S32 = loco::DataType::S32;

  auto split = dynamic_cast<const luci::CircleSplitV *>(node->input());
  if (split == nullptr)
    INTERNAL_EXN("CircleSplit IR is not configured correctly");

  loco::NodeShape unknown;

  auto split_shape = loco::shape_get(split).as<loco::TensorShape>();

  auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits());
  if (size_splits == nullptr)
    return unknown; // we need CircleConst for size_splits
  LUCI_ASSERT(size_splits->dtype() == S32, "Only support int32 for size_splits");

  auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
  if (split_dim == nullptr)
    return unknown; // we need CircleConst for split_dim
  LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");

  assert(split_dim->size<S32>() == 1);
  auto split_dim_axis = split_dim->at<S32>(0);
  // normalize negative axis
  if (split_dim_axis < 0)
    split_dim_axis += split_shape.rank();

  // interpret size_splits values
  int32_t size_splits_count = static_cast<int32_t>(size_splits->size<S32>());
  assert(size_splits_count == split->num_split());

  // tally explicit sizes and count -1 (inferred) entries
  int64_t minus_one_count = 0, size_splits_sum = 0;
  for (int32_t idx = 0; idx < size_splits_count; ++idx)
  {
    auto size = size_splits->at<S32>(idx);
    if (size == -1)
      ++minus_one_count;
    else
      size_splits_sum += size;
  }
  if (minus_one_count > 1)
    INTERNAL_EXN("CircleSplitV size_splits has more than two -1 values");

  // calcuate this SplitVOut shape
  auto input_size = split_shape.dim(split_dim_axis).value();
  assert(size_splits_sum <= input_size);

  auto index_this = node->index();
  assert(0 <= index_this && index_this < split->num_split());
  auto split_depth = size_splits->at<S32>(index_this);
  // -1 takes whatever remains after the explicit sizes
  if (split_depth == -1)
    split_depth = input_size - size_splits_sum;

  loco::TensorShape output_shape = split_shape;

  output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);

  return loco::NodeShape{output_shape};
}
1907 loco::NodeShape infer_top_k_v2_out(const luci::CircleTopKV2Out *node)
1909 const loco::DataType S32 = loco::DataType::S32;
1911 auto topkv2 = dynamic_cast<const luci::CircleTopKV2 *>(node->input());
1912 if (topkv2 == nullptr)
1913 INTERNAL_EXN("CircleSplit IR is not configured correctly");
1915 // shape of topkv2 is same as topkv2->input()
1916 auto input_shape = loco::shape_get(topkv2).as<loco::TensorShape>();
1918 auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k());
1919 LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32");
1920 assert(node_k->size<S32>() == 1);
1922 loco::TensorShape output_shape;
1924 output_shape.rank(input_shape.rank());
1925 for (uint32_t idx = 0; idx < input_shape.rank() - 1; ++idx)
1927 output_shape.dim(idx) = input_shape.dim(idx);
1929 output_shape.dim(input_shape.rank() - 1) = node_k->at<S32>(0);
1931 return loco::NodeShape{output_shape};
1934 loco::NodeShape infer_unique_out(const luci::CircleUniqueOut *node)
1936 if (node->index() == 0)
1938 auto unique_shape = own_shape(node);
1939 return loco::NodeShape{unique_shape};
1941 assert(node->index() == 1);
1942 auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
1943 auto unique_shape = loco::shape_get(unique->input()).as<loco::TensorShape>();
1945 assert(unique_shape.rank() == 1);
1947 loco::TensorShape shape_output;
1948 shape_output.rank(1);
1949 shape_output.dim(0) = unique_shape.dim(0);
1950 return loco::NodeShape{shape_output};
1953 loco::NodeShape infer_unpack_out(const luci::CircleUnpackOut *node)
1955 auto unpack = dynamic_cast<const luci::CircleUnpack *>(node->input());
1956 if (unpack == nullptr)
1958 INTERNAL_EXN("CircleUnpack IR is not configured correctly");
1961 auto unpack_shape = loco::shape_get(unpack).as<loco::TensorShape>();
1963 return loco::NodeShape{unpack_shape};
loco::NodeShape infer_while_out(const luci::CircleWhileOut *node)
{
  /**
   * @note  WHILE operator's shape is the same with the "cond"
   *        Graph input.
   */
  auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
  if (circle_while == nullptr)
  {
    INTERNAL_EXN("CircleWhile IR is not configured correctly");
  }

  auto index = node->index();
  auto cond_graph = circle_while->cond_graph();
  assert(cond_graph != nullptr);

  // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
  // loco::input_nodes
  auto cond_inputs = loco::input_nodes(cond_graph);
  auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));

  auto cond_graph_inputs = cond_graph->inputs();
  auto cond_graph_input = cond_graph_inputs->at(cond_in->index());

  auto cond_graph_input_shape = *cond_graph_input->shape();
  auto this_shape = own_shape(node);

  // mismatch is only warned about; the node's own shape still wins below
  if (!(this_shape == cond_graph_input_shape))
  {
    LOGGER(l);
    WARN(l) << "Warning: CircleWhileOut '" << node->name() << "' shape mispatch " << this_shape
            << " vs " << cond_graph_input_shape;
  }

  return loco::NodeShape{this_shape};
}
2004 * @brief Class to infer the shape of CircleNode
2006 * @note All CircleNode's inputs and outputs are always loco::Domain::Tensor
2008 class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeShape>
2011 loco::NodeShape visit(const luci::CircleAbs *node) final { return use_x(node); }
2013 loco::NodeShape visit(const luci::CircleAdd *node) final { return broadcast_xy(node); }
2015 loco::NodeShape visit(const luci::CircleAddN *node) final { return infer_add_n(node); }
2017 loco::NodeShape visit(const luci::CircleArgMax *node) final { return infer_arg_max(node); }
2019 loco::NodeShape visit(const luci::CircleArgMin *node) final { return infer_arg_min(node); }
2021 loco::NodeShape visit(const luci::CircleAveragePool2D *node) final
2023 return infer_pool_2d_shape(node);
2026 loco::NodeShape visit(const luci::CircleBatchMatMul *node) final
2028 auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
2029 auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
2031 return infer_batchmatmul_shape(x_shape, y_shape, node->adj_x(), node->adj_y());
2034 loco::NodeShape visit(const luci::CircleBatchToSpaceND *node) final
2036 return infer_batch_to_space_nd(node);
2039 loco::NodeShape visit(const luci::CircleCast *node) final { return use_x(node); }
2041 loco::NodeShape visit(const luci::CircleCeil *node) final { return use_x(node); }
2043 loco::NodeShape visit(const luci::CircleConcatenation *node) final
2045 return infer_concatenation(node);
2048 loco::NodeShape visit(const luci::CircleConst *node) final { return use_own(node); }
2050 loco::NodeShape visit(const luci::CircleConv2D *node) final { return infer_conv2d(node); }
2052 loco::NodeShape visit(const luci::CircleCos *node) final { return use_x(node); }
2054 loco::NodeShape visit(const luci::CircleCustom *node) final { return use_own(node); }
2056 loco::NodeShape visit(const luci::CircleDepthToSpace *node) final
2058 return infer_depth_to_space(node);
2061 loco::NodeShape visit(const luci::CircleDepthwiseConv2D *node) final
2063 return infer_depthwise_conv2d(node);
2066 loco::NodeShape visit(const luci::CircleDequantize *node) final
2068 const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2069 return loco::NodeShape{input_shape};
2072 loco::NodeShape visit(const luci::CircleDiv *node) final { return broadcast_xy(node); }
2074 loco::NodeShape visit(const luci::CircleElu *node) final
2076 auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2078 return loco::NodeShape{input_shape};
2081 loco::NodeShape visit(const luci::CircleEqual *node) final { return broadcast_xy(node); }
2083 loco::NodeShape visit(const luci::CircleExp *node) final { return use_x(node); }
2085 loco::NodeShape visit(const luci::CircleExpandDims *node) final
2087 return infer_expand_dims(node);
2090 loco::NodeShape visit(const luci::CircleFill *node) final { return infer_fill(node); }
2092 loco::NodeShape visit(const luci::CircleFloor *node) final { return use_x(node); }
2094 loco::NodeShape visit(const luci::CircleFloorDiv *node) final { return broadcast_xy(node); }
2096 loco::NodeShape visit(const luci::CircleFloorMod *node) final { return broadcast_xy(node); }
2098 loco::NodeShape visit(const luci::CircleFullyConnected *node) final
2100 return infer_fully_connected(node);
2103 loco::NodeShape visit(const luci::CircleGather *node) final { return infer_gather(node); }
2105 loco::NodeShape visit(const luci::CircleGatherNd *node) final { return infer_gather_nd(node); }
2107 loco::NodeShape visit(const luci::CircleGreater *node) final { return broadcast_xy(node); }
2109 loco::NodeShape visit(const luci::CircleGreaterEqual *node) final { return broadcast_xy(node); }
2111 loco::NodeShape visit(const luci::CircleIf *node) final
2113 // Shape of CircleIf is not used. Just use input 0
2114 assert(node->input_count() > 0);
2115 const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
2116 return loco::NodeShape{input_shape};
2119 loco::NodeShape visit(const luci::CircleL2Normalize *node) final { return use_x(node); }
2121 loco::NodeShape visit(const luci::CircleL2Pool2D *node) final
2123 return infer_pool_2d_shape(node);
2126 loco::NodeShape visit(const luci::CircleLeakyRelu *node) final
2128 const auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2129 return loco::NodeShape{input_shape};
2132 loco::NodeShape visit(const luci::CircleLess *node) final { return broadcast_xy(node); }
2134 loco::NodeShape visit(const luci::CircleLessEqual *node) final { return broadcast_xy(node); }
2136 loco::NodeShape visit(const luci::CircleLocalResponseNormalization *node) final
2138 const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2139 return loco::NodeShape{input_shape};
2142 loco::NodeShape visit(const luci::CircleLog *node) final { return use_x(node); }
2144 loco::NodeShape visit(const luci::CircleLogicalAnd *node) final { return use_x(node); }
2146 loco::NodeShape visit(const luci::CircleLogicalNot *node) final { return use_x(node); }
2148 loco::NodeShape visit(const luci::CircleLogicalOr *node) final { return use_x(node); }
2150 loco::NodeShape visit(const luci::CircleLogistic *node) final { return use_x(node); }
2152 loco::NodeShape visit(const luci::CircleLogSoftmax *node) final { return use_logits(node); }
2154 loco::NodeShape visit(const luci::CircleMatrixDiag *node) final
2156 return infer_matrix_diag(node);
2159 loco::NodeShape visit(const luci::CircleMatrixSetDiag *node) final
2161 return infer_matrix_set_diag(node);
2164 loco::NodeShape visit(const luci::CircleMaximum *node) final { return broadcast_xy(node); }
2166 loco::NodeShape visit(const luci::CircleMaxPool2D *node) final
2168 return infer_pool_2d_shape(node);
2171 loco::NodeShape visit(const luci::CircleMean *node) final
2173 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2174 return loco::NodeShape{output_shape};
2177 loco::NodeShape visit(const luci::CircleMinimum *node) final { return broadcast_xy(node); }
2179 loco::NodeShape visit(const luci::CircleMirrorPad *node) final { return infer_mirror_pad(node); }
2181 loco::NodeShape visit(const luci::CircleMul *node) final { return broadcast_xy(node); }
2183 loco::NodeShape visit(const luci::CircleNeg *node) final { return use_x(node); }
2185 loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final
2187 const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>();
2188 return loco::NodeShape{boxes_shape};
2191 loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5 *node) final
2193 const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>();
2194 return loco::NodeShape{boxes_shape};
2197 loco::NodeShape visit(const luci::CircleNotEqual *node) final { return broadcast_xy(node); }
2199 loco::NodeShape visit(const luci::CircleOneHot *node) final { return infer_one_hot(node); }
2201 loco::NodeShape visit(const luci::CirclePack *node) final { return infer_pack(node); }
2203 loco::NodeShape visit(const luci::CirclePad *node) final { return infer_pad(node); }
2205 loco::NodeShape visit(const luci::CirclePadV2 *node) final { return infer_pad_v2(node); }
2207 loco::NodeShape visit(const luci::CirclePow *node) final { return broadcast_xy(node); }
2209 loco::NodeShape visit(const luci::CirclePRelu *node) final { return infer_p_relu(node); }
2211 loco::NodeShape visit(const luci::CircleRange *node) final { return infer_range(node); }
2213 loco::NodeShape visit(const luci::CircleRank *) final
2215 loco::TensorShape shape_output;
2216 shape_output.rank(0);
2218 return loco::NodeShape{shape_output};
2221 loco::NodeShape visit(const luci::CircleReduceAny *node) final
2223 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2224 return loco::NodeShape{output_shape};
2227 loco::NodeShape visit(const luci::CircleReduceMax *node) final
2229 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2230 return loco::NodeShape{output_shape};
2233 loco::NodeShape visit(const luci::CircleReduceMin *node) final
2235 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2236 return loco::NodeShape{output_shape};
2239 loco::NodeShape visit(const luci::CircleReduceProd *node) final
2241 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2242 return loco::NodeShape{output_shape};
2245 loco::NodeShape visit(const luci::CircleRelu *node) final
2247 auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2249 return loco::NodeShape{input_shape};
2252 loco::NodeShape visit(const luci::CircleRelu6 *node) final
2254 auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2256 return loco::NodeShape{input_shape};
2259 loco::NodeShape visit(const luci::CircleReluN1To1 *node) final
2261 auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2263 return loco::NodeShape{input_shape};
2267 * @note CircleReshape has new shape info in two places: 2nd input and attribute.
2268 * This shape inference uses shape from input 'shape' node when it's constant.
2269 * If not, shape will be from node itself. shape from attribute is not used.
2271 * TODO Change this policy when not appropriate
2273 loco::NodeShape visit(const luci::CircleReshape *node) final { return infer_reshape(node); }
2275 loco::NodeShape visit(const luci::CircleResizeBilinear *node) final
2277 return infer_resize_bilinear(node);
2280 loco::NodeShape visit(const luci::CircleResizeNearestNeighbor *node) final
2282 return infer_resize_nearest_neighbor(node);
2285 loco::NodeShape visit(const luci::CircleReverseSequence *node) final
2287 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2289 return loco::NodeShape{input_shape};
2292 loco::NodeShape visit(const luci::CircleRound *node) final { return use_x(node); }
2294 loco::NodeShape visit(const luci::CircleReverseV2 *node) final
2296 auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
2298 LUCI_ASSERT(loco::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1,
2299 "Tensor must be 1-D");
2301 return loco::NodeShape{input_shape};
2304 loco::NodeShape visit(const luci::CircleRsqrt *node) final { return use_x(node); }
2306 loco::NodeShape visit(const luci::CircleScatterNd *node) final { return infer_scatter_nd(node); }
2308 loco::NodeShape visit(const luci::CircleSegmentSum *node) final
2310 return infer_segment_sum(node);
2313 loco::NodeShape visit(const luci::CircleSelect *node) final { return infer_select(node); }
2315 loco::NodeShape visit(const luci::CircleSelectV2 *node) final { return infer_select_v2(node); }
2317 loco::NodeShape visit(const luci::CircleShape *node) final { return infer_shape(node); }
2319 loco::NodeShape visit(const luci::CircleSin *node) final { return use_x(node); }
2321 loco::NodeShape visit(const luci::CircleSlice *node) final { return infer_slice(node); }
2323 loco::NodeShape visit(const luci::CircleSoftmax *node) final { return use_logits(node); }
2325 loco::NodeShape visit(const luci::CircleSpaceToBatchND *node) final
2327 return infer_space_to_batch_nd(node);
2330 loco::NodeShape visit(const luci::CircleSpaceToDepth *node) final
2332 return infer_space_to_depth(node);
2335 loco::NodeShape visit(const luci::CircleSparseToDense *node) final
2337 return infer_sparse_to_dense(node);
2340 loco::NodeShape visit(const luci::CircleSplit *node) final
2342 // We'll set Split output as same as input so that SplitOut can handle it's own shape
2343 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2344 return loco::NodeShape{input_shape};
2347 loco::NodeShape visit(const luci::CircleSplitV *node) final
2349 // We'll set SplitV output as same as input so that SplitOut can handle it's own shape
2350 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2351 return loco::NodeShape{input_shape};
2354 loco::NodeShape visit(const luci::CircleSqrt *node) final { return use_x(node); }
2356 loco::NodeShape visit(const luci::CircleSquare *node) final { return use_x(node); }
2358 loco::NodeShape visit(const luci::CircleSquaredDifference *node) final
2360 return broadcast_xy(node);
2363 loco::NodeShape visit(const luci::CircleStridedSlice *node) final
2365 return infer_strided_slice(node);
2368 loco::NodeShape visit(const luci::CircleSqueeze *node) final { return infer_squeeze(node); }
2370 loco::NodeShape visit(const luci::CircleSub *node) final { return broadcast_xy(node); }
2372 loco::NodeShape visit(const luci::CircleSum *node) final
2374 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2375 return loco::NodeShape{output_shape};
2378 loco::NodeShape visit(const luci::CircleTanh *node) final { return use_x(node); }
2380 loco::NodeShape visit(const luci::CircleTile *node) final { return infer_tile(node); }
2382 loco::NodeShape visit(const luci::CircleTopKV2 *node) final
2384 // set shape of this node as same as input
2385 const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2386 return loco::NodeShape{input_shape};
2389 loco::NodeShape visit(const luci::CircleTranspose *node) final { return infer_transpose(node); }
2391 loco::NodeShape visit(const luci::CircleTransposeConv *node) final
2393 return infer_transpose_conv(node);
2396 loco::NodeShape visit(const luci::CircleUnpack *node) final { return infer_unpack(node); }
2398 loco::NodeShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
2400 return infer_unidirectionalsequencelstm(node);
2403 loco::NodeShape visit(const luci::CircleUnique *node) final { return infer_unique(node); }
2405 loco::NodeShape visit(const luci::CircleWhere *node) final { return use_own(node); }
2407 loco::NodeShape visit(const luci::CircleWhile *node) final
2409 // Shape of CircleWhile is not used. Just use input 0
2410 assert(node->arity() > 0);
2411 const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
2412 return loco::NodeShape{input_shape};
2415 loco::NodeShape visit(const luci::CircleZerosLike *node) final
2417 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2419 return loco::NodeShape{input_shape};
2423 loco::NodeShape visit(const luci::CircleBCQFullyConnected *node) final
2425 return infer_bcq_fully_connected(node);
2428 loco::NodeShape visit(const luci::CircleBCQGather *node) final { return infer_bcq_gather(node); }
2430 loco::NodeShape visit(const luci::CircleInstanceNorm *node) final
2432 auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2434 return loco::NodeShape{input_shape};
2438 loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); }
2440 loco::NodeShape visit(const luci::CircleOutput *node) final { return infer_output(node); }
2442 loco::NodeShape visit(const luci::CircleOutputDummy *node) final { return use_own(node); }
2444 loco::NodeShape visit(const luci::CircleOutputExclude *node) final { return use_own(node); }
2446 loco::NodeShape visit(const luci::CircleCustomOut *node) final { return use_own(node); }
2448 loco::NodeShape visit(const luci::CircleIfOut *node) final { return infer_if_out(node); }
2450 loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final
2452 return infer_non_max_suppression_v4_out(node);
2455 loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final
2457 return infer_non_max_suppression_v5_out(node);
2460 loco::NodeShape visit(const luci::CircleSplitOut *node) final { return infer_split_out(node); }
2462 loco::NodeShape visit(const luci::CircleSplitVOut *node) final { return infer_split_v_out(node); }
2464 loco::NodeShape visit(const luci::CircleTopKV2Out *node) final
2466 return infer_top_k_v2_out(node);
2469 loco::NodeShape visit(const luci::CircleUniqueOut *node) final { return infer_unique_out(node); }
2471 loco::NodeShape visit(const luci::CircleUnpackOut *node) final { return infer_unpack_out(node); }
2473 loco::NodeShape visit(const luci::CircleWhileOut *node) final { return infer_while_out(node); }
2481 bool CircleShapeInferenceRule::recognize(const loco::Dialect *d) const
2483 return CircleDialect::get() == d;
2486 bool CircleShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
2490 assert(node->dialect() == CircleDialect::get());
2492 ShapeInferenceAlgorithm alg;
2493 auto circle_node = loco::must_cast<const CircleNode *>(node);
2495 bool is_shape_undefined = (circle_node->shape_status() == ShapeStatus::UNDEFINED);
2496 bool is_shape_none = (circle_node->shape_status() == ShapeStatus::NOSHAPE);
2497 bool is_scalar = (circle_node->rank() == 0);
2499 if (is_shape_undefined)
2500 shape = circle_node->accept(&alg);
2503 if (is_shape_none || is_scalar)
2504 shape = own_shape(circle_node);
2506 shape = circle_node->accept(&alg);
2509 VERBOSE(l, 1) << "[luci] shape: " << circle_node->name();
2510 VERBOSE(l, 1) << " own_shape: " << own_shape(circle_node)
2511 << " -> infer: " << shape.as<loco::TensorShape>();