2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "luci/Service/CircleShapeInferenceRule.h"
21 #include "CircleShapeInferenceHelper.h"
22 #include "ShapeInfer_StridedSlice.h"
24 #include <luci/IR/CircleNodes.h>
25 #include <luci/IR/CircleDialect.h>
26 #include <luci/IR/CircleNodeVisitor.h>
29 #include <oops/InternalExn.h>
39 std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
42 for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
47 if (tensor_shape.dim(r).known())
48 os << tensor_shape.dim(r).value();
56 loco::TensorShape own_shape(const luci::CircleNode *node)
58 loco::TensorShape shape;
59 shape.rank(node->rank());
60 for (uint32_t r = 0; r < node->rank(); ++r)
62 // Shape inference rules in this file did not consider unknown dimension.
63 // If some node has unknown dimension, 0 is inserted and wrong shape
64 // inference was done as a result.
65 // To fix this, new shape inference algorithm is being implemented.
66 // Until new inference algorithm is fully implemented, unknown dimension
67 // would be represented as 1 along with TFLite expression.
68 shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
73 loco::NodeShape use_own(const luci::CircleNode *node)
75 loco::TensorShape shape = own_shape(node);
76 return loco::NodeShape{shape};
80 * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
84 * auto expanded_tensor_shape = expand(tensor_shape).to(N);
86 class TensorShapeExpander
89 TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
95 loco::TensorShape to(uint32_t output_rank)
97 auto const &input_shape = _shape;
98 uint32_t const input_rank = input_shape.rank();
100 assert(input_rank <= output_rank && "Cannot shrink rank");
101 uint32_t const axis_shift = output_rank - input_rank;
103 loco::TensorShape output_shape;
105 output_shape.rank(output_rank);
106 for (uint32_t axis = 0; axis < output_rank; ++axis)
108 output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
115 const loco::TensorShape _shape;
119 * @brief Expand shape x and y to same rank by align right and filling with 1
121 void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
123 auto x_rank = x.rank();
124 auto y_rank = y.rank();
126 if (x_rank == y_rank)
129 TensorShapeExpander x_exp(x);
130 TensorShapeExpander y_exp(y);
132 auto xy_rank = std::max(x_rank, y_rank);
134 x = x_rank > y_rank ? x : x_exp.to(xy_rank);
135 y = y_rank > x_rank ? y : y_exp.to(xy_rank);
139 * @brief Returns shape of expanded dimension of input x and y having same rank
141 loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
143 assert(x.rank() == y.rank());
145 auto rank = x.rank();
147 loco::TensorShape output_shape;
149 output_shape.rank(rank);
150 for (uint32_t axis = 0; axis < rank; ++axis)
152 auto x_dim = x.dim(axis).known() ? x.dim(axis).value() : 1;
153 auto y_dim = y.dim(axis).known() ? y.dim(axis).value() : 1;
155 // each dimension of x and y should be same or one must be 1 if different
156 if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
157 INTERNAL_EXN("Cannot produce expand_dimension of two shapes");
159 output_shape.dim(axis) = std::max(x_dim, y_dim);
165 loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
170 expand_rank(x_match, y_match);
172 auto output_shape = expand_dimension(x_match, y_match);
178 * @brief vector_from_constant will return int64_t vector from CircleConst node
180 template <loco::DataType T> std::vector<int64_t> vector_from_constant(luci::CircleConst *const_node)
182 std::vector<int64_t> result;
184 for (uint32_t idx = 0; idx < const_node->size<T>(); ++idx)
185 result.push_back(const_node->at<T>(idx));
190 template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node)
192 auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>();
193 auto y_shape = luci::shape_get(node->y()).template as<loco::TensorShape>();
195 auto output_shape = broadcast_shape(x_shape, y_shape);
197 return loco::NodeShape{output_shape};
200 #define DECLARE_USE_SINGLE(NAME) \
201 template <class CIRCLENODE> loco::NodeShape use_##NAME(const CIRCLENODE *node) \
203 auto inputs_shape = luci::shape_get(node->NAME()).template as<loco::TensorShape>(); \
204 return loco::NodeShape{inputs_shape}; \
207 DECLARE_USE_SINGLE(input);
208 DECLARE_USE_SINGLE(inputs);
209 DECLARE_USE_SINGLE(x);
210 DECLARE_USE_SINGLE(logits);
212 #undef DECLARE_USE_SINGLE
214 template <class CIRCLENODE>
215 loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *paddings)
217 const loco::DataType S32 = loco::DataType::S32;
219 auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
221 // TODO support other data type
222 LUCI_ASSERT(paddings->dtype() == S32, "Only support int 32 for now");
223 LUCI_ASSERT(paddings->rank() == 2, "paddings should be rank 2")
225 int32_t n = paddings->dim(0).value();
226 int32_t v = paddings->dim(1).value();
228 LUCI_ASSERT(v == 2, "paddings should be [n, 2]");
229 LUCI_ASSERT(n == int32_t(input_shape.rank()),
230 "paddings [n, 2] should have same value of input rank");
232 loco::TensorShape output_shape;
234 output_shape.rank(input_shape.rank());
235 for (int32_t ni = 0; ni < n; ++ni)
237 int32_t idx = ni * 2;
238 int value = input_shape.dim(ni).value();
239 value += paddings->at<S32>(idx + 0); // left
240 value += paddings->at<S32>(idx + 1); // right
241 output_shape.dim(ni) = value;
244 return loco::NodeShape{output_shape};
247 loco::NodeShape infer_add_n(const luci::CircleAddN *node)
249 auto shape = luci::shape_get(node->inputs(0)).as<loco::TensorShape>();
251 for (uint32_t idx = 1; idx < node->arity(); ++idx)
253 auto shape_idx = luci::shape_get(node->inputs(idx)).as<loco::TensorShape>();
254 if (!(shape == shape_idx))
256 INTERNAL_EXN_V("ADD_N shape not same as the first input: ", idx);
259 return loco::NodeShape{shape};
262 template <class CIRCLENODE> loco::NodeShape infer_arg_maxmin(const CIRCLENODE *node)
264 auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
265 auto dimension_shape = luci::shape_get(node->dimension()).template as<loco::TensorShape>();
267 int64_t select_axis = 0;
269 LUCI_ASSERT(node->dimension(), "2nd input dimension() should not be nullptr");
271 // Only support node's shape() is CircleConst with S32/S64
272 // Support S32 for now.
273 auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
274 LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
275 "Only support int32 CircleConst for CircleArgMax/CircleArgMin");
277 if (const_shape_node->rank() > 1)
278 INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
279 oops::to_uint32(const_shape_node->rank()));
281 select_axis = const_shape_node->template scalar<loco::DataType::S32>();
284 assert(select_axis < input_shape.rank());
287 select_axis += input_shape.rank();
289 // NOTE select_axis is removed
290 loco::TensorShape shape_output;
291 uint32_t rank = input_shape.rank();
292 uint32_t shrink = static_cast<uint32_t>(select_axis);
294 shape_output.rank(rank - 1);
295 for (uint32_t r = 0, d = 0; r < rank; ++r)
299 shape_output.dim(d++) = input_shape.dim(r);
301 return loco::NodeShape{shape_output};
304 // Call this for CircleAvgPool2D and CircleMaxPool2D only
305 template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
307 auto ifm_shape = luci::shape_get(node->value()).template as<loco::TensorShape>();
308 assert(ifm_shape.rank() == 4);
309 assert(ifm_shape.dim(1).known());
310 assert(ifm_shape.dim(2).known());
312 uint32_t input_height = ifm_shape.dim(1).value();
313 uint32_t input_width = ifm_shape.dim(2).value();
314 uint32_t stride_height = node->stride()->h();
315 uint32_t stride_width = node->stride()->w();
316 uint32_t window_height = node->filter()->h();
317 uint32_t window_width = node->filter()->w();
318 uint32_t dilation_height = 1; // dilation for CircleAvgPool2D and CircleMaxPool2D is 1
319 uint32_t dilation_width = 1;
320 uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
321 uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
323 uint32_t output_height = 0;
324 uint32_t output_width = 0;
326 if (node->padding() == luci::Padding::VALID)
328 LUCI_ASSERT(input_height + stride_height > effective_window_height, "Invalid shape");
329 LUCI_ASSERT(input_width + stride_width > effective_window_width, "Invalid shape");
330 output_height = (input_height + stride_height - effective_window_height) / stride_height;
331 output_width = (input_width + stride_width - effective_window_width) / stride_width;
333 else if (node->padding() == luci::Padding::SAME)
335 output_height = (input_height + stride_height - 1) / stride_height;
336 output_width = (input_width + stride_width - 1) / stride_width;
339 LUCI_ASSERT(false, "Wrong padding type");
341 loco::TensorShape ofm_shape;
343 ofm_shape.dim(0) = ifm_shape.dim(0);
344 ofm_shape.dim(1) = output_height;
345 ofm_shape.dim(2) = output_width;
346 ofm_shape.dim(3) = ifm_shape.dim(3);
348 return loco::NodeShape{ofm_shape};
351 loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node)
353 const loco::DataType S32 = loco::DataType::S32;
355 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
356 // Support only input rank is 3 and 4
357 assert(input_shape.rank() == 3 || input_shape.rank() == 4);
359 // Only support block_shape() with S32 type CircleConst for now
360 auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
361 LUCI_ASSERT(const_block_shape->dtype() == loco::DataType::S32, "Only support int32 block_shape");
363 // Only support crops() with S32 type CircleConst for now
364 auto const_crops = loco::must_cast<luci::CircleConst *>(node->crops());
365 LUCI_ASSERT(const_crops->dtype() == loco::DataType::S32, "Only support int32 crops");
367 auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>();
368 auto const_crops_shape = luci::shape_get(const_crops).as<loco::TensorShape>();
369 assert(const_block_shape_shape.rank() == 1);
370 assert(const_crops_shape.rank() == 2);
372 int32_t input_spatial_dim = input_shape.rank() - 2;
373 assert(const_block_shape_shape.dim(0) == input_spatial_dim);
374 assert(const_crops_shape.dim(0) == input_spatial_dim);
375 assert(const_crops_shape.dim(1) == 2);
377 loco::TensorShape shape_output;
379 shape_output.rank(input_shape.rank());
381 int32_t output_batch_size = input_shape.dim(0).value();
382 for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
384 int dim_size = input_shape.dim(dim + 1).value() * const_block_shape->at<S32>(dim);
385 dim_size -= const_crops->at<S32>(dim * 2);
386 dim_size -= const_crops->at<S32>(dim * 2 + 1);
387 shape_output.dim(dim + 1) = dim_size;
389 assert(output_batch_size % const_block_shape->at<S32>(dim) == 0);
390 output_batch_size = output_batch_size / const_block_shape->at<S32>(dim);
392 shape_output.dim(0) = output_batch_size;
393 shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);
395 return loco::NodeShape{shape_output};
404 template <class Conv2DType> OutputSize infer_conv2d_type(const Conv2DType *node)
406 auto ifm_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
407 auto ker_shape = luci::shape_get(node->filter()).template as<loco::TensorShape>();
408 assert(ifm_shape.rank() == 4);
409 assert(ker_shape.rank() == 4);
410 assert(ifm_shape.dim(1).known());
411 assert(ifm_shape.dim(2).known());
412 assert(ker_shape.dim(1).known());
413 assert(ker_shape.dim(2).known());
415 uint32_t input_height = ifm_shape.dim(1).value();
416 uint32_t input_width = ifm_shape.dim(2).value();
417 uint32_t stride_height = node->stride()->h();
418 uint32_t stride_width = node->stride()->w();
419 uint32_t ker_height = ker_shape.dim(1).value();
420 uint32_t ker_width = ker_shape.dim(2).value();
421 uint32_t dilation_height = node->dilation()->h();
422 uint32_t dilation_width = node->dilation()->w();
423 uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
424 uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
426 uint32_t output_height = 0;
427 uint32_t output_width = 0;
429 if (node->padding() == luci::Padding::VALID)
431 LUCI_ASSERT(input_height + stride_height > effective_ker_height, "Invalid shape");
432 LUCI_ASSERT(input_width + stride_width > effective_ker_width, "Invalid shape");
433 output_height = (input_height + stride_height - effective_ker_height) / stride_height;
434 output_width = (input_width + stride_width - effective_ker_width) / stride_width;
436 else if (node->padding() == luci::Padding::SAME)
438 output_height = (input_height + stride_height - 1) / stride_height;
439 output_width = (input_width + stride_width - 1) / stride_width;
442 LUCI_ASSERT(false, "Wrong padding type");
444 OutputSize os{output_height, output_width};
449 // BatchMatMulV2 supports broadcasting in the batch dimensions(BatchMatMul doesn't)
450 // TODO Distinguish BatchMatMul and BatchMatMulV2
451 loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape,
452 const loco::TensorShape &y_shape, bool adj_x, bool adj_y)
454 uint32_t x_rank = x_shape.rank();
455 uint32_t y_rank = y_shape.rank();
456 assert(x_rank >= 2 && y_rank >= 2);
458 loco::TensorShape output_shape;
459 output_shape.rank(x_shape.rank());
460 // Braodcast in the batch dimension
461 if (x_rank > 2 || y_rank > 2)
463 loco::TensorShape dummy_x = x_shape;
464 loco::TensorShape dummy_y = y_shape;
465 expand_rank(dummy_x, dummy_y);
467 expand_rank(output_shape, dummy_y);
469 for (uint32_t d = 0; d < output_shape.rank() - 2; d++)
471 uint32_t max_dim = std::max(dummy_x.dim(d).value(), dummy_y.dim(d).value());
472 if (dummy_x.dim(d) == dummy_y.dim(d) ||
473 dummy_x.dim(d).value() * dummy_y.dim(d).value() == max_dim)
474 output_shape.dim(d).set(max_dim);
476 INTERNAL_EXN("BatchMatMul has wrong shape");
480 loco::Dimension x_lhs = adj_x ? x_shape.dim(x_rank - 1) : x_shape.dim(x_rank - 2);
481 loco::Dimension x_rhs = adj_x ? x_shape.dim(x_rank - 2) : x_shape.dim(x_rank - 1);
482 loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
483 loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);
485 if (x_rhs.known() && y_lhs.known() && not(x_rhs == y_lhs))
486 INTERNAL_EXN("x_rhs and y_lhs should be same");
488 uint32_t out_rank = output_shape.rank();
489 output_shape.dim(out_rank - 2) = x_lhs;
490 output_shape.dim(out_rank - 1) = y_rhs;
492 return loco::NodeShape{output_shape};
495 loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node)
497 // TODO Support when CircleConcatenation has 0 input
498 assert(node->numValues() > 0);
500 auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>();
501 auto axis = node->axis();
503 axis += first_shape.rank();
506 assert(first_shape.rank() > static_cast<uint32_t>(axis));
508 loco::TensorShape output_shape;
510 output_shape.rank(first_shape.rank());
511 for (uint32_t i = 0; i < output_shape.rank(); ++i)
512 output_shape.dim(i) = first_shape.dim(i);
514 for (uint32_t i = 1; i < node->numValues(); ++i)
516 auto input_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>();
518 for (uint32_t j = 0; j < output_shape.rank(); ++j)
520 if (j == static_cast<uint32_t>(axis))
522 // If dimension is unknown, value() will return 0.
523 // This is wrong but until new inference algorithm is implemented,
524 // this code will not be modified to keep compatibility.
525 output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
528 assert(!output_shape.dim(j).known() || !input_shape.dim(j).known() ||
529 output_shape.dim(j) == input_shape.dim(j));
533 return loco::NodeShape{output_shape};
536 loco::NodeShape infer_conv2d(const luci::CircleConv2D *node)
540 auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
541 auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
543 assert(ifm_shape.rank() == 4);
544 assert(ker_shape.rank() == 4);
545 assert(ifm_shape.dim(3) == ker_shape.dim(3));
547 auto os = infer_conv2d_type(node);
549 loco::TensorShape ofm_shape;
551 ofm_shape.dim(0) = ifm_shape.dim(0);
552 ofm_shape.dim(1) = os.height;
553 ofm_shape.dim(2) = os.width;
554 ofm_shape.dim(3) = ker_shape.dim(0);
556 INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank()
557 << ") output(" << ofm_shape.dim(0).value() << "," << ofm_shape.dim(1).value() << ","
558 << ofm_shape.dim(2).value() << "," << ofm_shape.dim(3).value() << ") " << node->name()
561 return loco::NodeShape{ofm_shape};
564 loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node)
566 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
567 LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
569 // Only data format NHWC is supported
570 // TODO need to clarify what to do with layout in this operator
571 int32_t height = input_shape.dim(1).value();
572 int32_t width = input_shape.dim(2).value();
573 int32_t depth = input_shape.dim(3).value();
575 int block_size = node->block_size();
578 INTERNAL_EXN("Block size must be >= 2");
580 if (depth % (block_size * block_size))
582 INTERNAL_EXN("The input tensor's depth must be divisible by block_size^2");
585 loco::TensorShape output_shape;
586 output_shape.rank(4);
588 output_shape.dim(0) = input_shape.dim(0).value();
589 output_shape.dim(1) = height * block_size;
590 output_shape.dim(2) = width * block_size;
591 output_shape.dim(3) = depth / (block_size * block_size);
593 return loco::NodeShape{output_shape};
596 loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node)
598 auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
599 auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
601 assert(ifm_shape.rank() == 4);
602 assert(ker_shape.rank() == 4);
603 assert(ker_shape.dim(0).value() == 1);
604 assert(ifm_shape.dim(3).value() * node->depthMultiplier() == ker_shape.dim(3).value());
606 auto os = infer_conv2d_type(node);
608 loco::TensorShape ofm_shape;
610 ofm_shape.dim(0) = ifm_shape.dim(0);
611 ofm_shape.dim(1) = os.height;
612 ofm_shape.dim(2) = os.width;
613 ofm_shape.dim(3) = ker_shape.dim(3);
615 return loco::NodeShape{ofm_shape};
618 loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node)
620 const loco::DataType S32 = loco::DataType::S32;
621 auto x_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
622 if (x_shape.rank() == 0)
624 // This maybe for unknown shape. We use shape from the node itself.
625 return use_own(node);
627 auto const_axis = loco::must_cast<luci::CircleConst *>(node->axis());
628 LUCI_ASSERT(const_axis->dtype() == S32, "Only support int32 CircleConst for axis");
629 if (const_axis->rank() != 0 && const_axis->rank() != 1)
631 INTERNAL_EXN_V("Non-scalar axis in OP", node->opnum());
633 int32_t axis = const_axis->at<S32>(0);
634 LUCI_ASSERT((axis <= static_cast<int32_t>(x_shape.rank())) &&
635 (axis >= -1 - static_cast<int32_t>(x_shape.rank())),
636 "Axis has to be between [-(D+1), D], where D is rank of input.");
637 size_t positive_axis = axis < 0 ? x_shape.rank() + axis + 1 : axis;
638 loco::TensorShape output_shape;
639 output_shape.rank(x_shape.rank() + 1);
641 for (; i < positive_axis; i++)
642 output_shape.dim(i) = x_shape.dim(i);
643 output_shape.dim(i) = loco::Dimension(1);
644 for (; i < x_shape.rank(); i++)
645 output_shape.dim(i + 1) = x_shape.dim(i);
646 return loco::NodeShape{output_shape};
649 loco::NodeShape infer_fill(const luci::CircleFill *node)
651 loco::TensorShape shape;
653 LUCI_ASSERT(node->dims(), "dims input should not be nullptr");
655 auto dims_node = dynamic_cast<luci::CircleConst *>(node->dims());
656 if (dims_node != nullptr)
658 // Only support node with S32
659 LUCI_ASSERT(dims_node->dtype() == loco::DataType::S32, "Only support int32 CircleConst");
661 if (dims_node->rank() != 1)
662 INTERNAL_EXN_V("Only support rank 1 CircleConst", oops::to_uint32(dims_node->rank()));
664 shape.rank(dims_node->dim(0).value());
666 for (uint32_t axis = 0; axis < shape.rank(); ++axis)
668 shape.dim(axis) = dims_node->at<loco::DataType::S32>(axis);
673 shape = own_shape(node);
677 return loco::NodeShape{shape};
680 loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
682 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
683 auto weights_shape = luci::shape_get(node->weights()).as<loco::TensorShape>();
685 loco::TensorShape out_shape;
687 // NOTE Some recipes in some repositories are using rank 4 input for FullyConnected.
688 // Until they are all fixed, disable following assert.
689 // TODO Enable following assert after related fixes are applied
690 // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L194
691 // LUCI_ASSERT(input_shape.rank() == 2 || input_shape.rank() == 3,
692 // "Input rank of FullyConnected should be 2 or 3");
694 // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L225
695 LUCI_ASSERT(weights_shape.rank() == 2, "Weights of FullyConnected should be 2");
697 // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L353-L367
698 if (node->keep_num_dims())
700 out_shape.rank(input_shape.rank());
701 for (uint32_t i = 0; i < input_shape.rank(); ++i)
702 out_shape.dim(i) = input_shape.dim(i);
703 out_shape.dim(out_shape.rank() - 1) = weights_shape.dim(0);
707 uint32_t input_size = 1;
708 for (uint32_t i = 0; i < input_shape.rank(); i++)
710 input_size = input_size * input_shape.dim(i).value();
712 const uint32_t batch_size = input_size / weights_shape.dim(1).value();
714 out_shape.dim(0) = batch_size;
715 out_shape.dim(1) = weights_shape.dim(0);
718 return loco::NodeShape{out_shape};
721 loco::NodeShape infer_gather(const luci::CircleGather *node)
723 loco::TensorShape output_shape;
725 const auto input_shape = luci::shape_get(node->params()).as<loco::TensorShape>();
726 const auto positions_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
727 int32_t axis = node->axis();
729 // If CircleGather input has a dynamic shape, it can't inference this shape. So, it returns the
730 // shape that node already has.
731 if (input_shape.rank() == 0 || positions_shape.rank() == 0)
732 return use_own(node);
735 axis += input_shape.rank();
737 output_shape.rank(input_shape.rank() - 1 + positions_shape.rank());
738 int32_t outdim_index = 0;
739 for (int32_t i = 0; i < axis; ++i)
740 output_shape.dim(outdim_index++) = input_shape.dim(i);
741 for (uint32_t i = 0; i < positions_shape.rank(); ++i)
742 output_shape.dim(outdim_index++) = positions_shape.dim(i);
743 for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
744 output_shape.dim(outdim_index++) = input_shape.dim(i);
746 return loco::NodeShape{output_shape};
749 loco::NodeShape infer_gather_nd(const luci::CircleGatherNd *node)
751 loco::TensorShape output_shape;
753 const auto params_shape = luci::shape_get(node->params()).as<loco::TensorShape>();
754 const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
756 const auto params_rank = params_shape.rank();
757 const auto indices_rank = indices_shape.rank();
759 // see https://www.tensorflow.org/api_docs/python/tf/gather_nd
760 // output.shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]
761 // batch_dims isn't supported in tflite
763 // TODO: replace exceptions with setting shape to unknown?
765 if (!indices_shape.dim(indices_rank - 1).known())
766 INTERNAL_EXN("Last indices dimension is unknown");
768 auto indices_last_dim = indices_shape.dim(indices_rank - 1).value();
770 if (indices_last_dim > params_rank)
771 INTERNAL_EXN("Last indices dimension should be <= params rank");
773 const uint32_t output_rank = indices_rank + params_rank - indices_last_dim - 1;
775 output_shape.rank(output_rank);
777 uint32_t output_index = 0;
778 for (uint32_t i = 0; i < indices_rank - 1; ++i)
780 auto &dim = indices_shape.dim(i);
782 INTERNAL_EXN("Unknown indices dimension is unsupported");
783 output_shape.dim(output_index++).set(dim.value());
786 for (uint32_t i = indices_last_dim; i < params_rank; ++i)
788 auto &dim = params_shape.dim(i);
790 INTERNAL_EXN("Unknown params dimension is unsupported");
791 output_shape.dim(output_index++).set(dim.value());
794 return loco::NodeShape{output_shape};
797 loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node)
799 loco::TensorShape output_shape;
801 auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>();
802 auto rank = diagonal_shape.rank();
804 output_shape.rank(rank + 1);
806 for (uint32_t i = 0; i < rank; i++)
808 output_shape.dim(i) = diagonal_shape.dim(i);
811 output_shape.dim(rank) = diagonal_shape.dim(rank - 1);
813 return loco::NodeShape{output_shape};
816 loco::NodeShape infer_matrix_set_diag(const luci::CircleMatrixSetDiag *node)
818 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
819 auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>();
821 auto rank = diagonal_shape.rank();
823 LUCI_ASSERT(rank == input_shape.rank() - 1, "diagonal rank = input rank - 1");
825 for (uint32_t i = 0; i < rank - 1; i++)
827 LUCI_ASSERT(diagonal_shape.dim(i) == input_shape.dim(i), "diagonal dims = input dims");
830 auto dim = std::min(input_shape.dim(rank - 1).value(), input_shape.dim(rank).value());
832 LUCI_ASSERT(dim == diagonal_shape.dim(rank - 1), "Max diag len error");
834 return loco::NodeShape{input_shape};
837 loco::TensorShape infer_reducer(const loco::Node *input, const loco::Node *indices, bool keep_dims)
839 const loco::DataType S32 = loco::DataType::S32;
841 auto input_shape = luci::shape_get(input).as<loco::TensorShape>();
842 auto reduction_indices = loco::must_cast<const luci::CircleConst *>(indices);
845 // TODO support non-const case
846 // TODO support other data type
847 LUCI_ASSERT(reduction_indices->dtype() == S32, "Only support int 32");
850 std::vector<int32_t> reduction_values;
852 for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i)
854 int32_t axis = reduction_indices->at<S32>(i);
856 axis += input_shape.rank();
857 if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank())))
858 INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
859 reduction_values.push_back(axis);
862 loco::TensorShape output_shape;
866 output_shape.rank(input_shape.rank());
867 for (uint32_t i = 0; i < input_shape.rank(); ++i)
868 output_shape.dim(i) = input_shape.dim(i);
869 for (uint32_t i = 0; i < reduction_values.size(); ++i)
870 output_shape.dim(reduction_values.at(i)) = 1;
874 std::vector<bool> check_reduce(input_shape.rank(), false);
875 for (uint32_t i = 0; i < reduction_values.size(); ++i)
876 check_reduce.at(reduction_values.at(i)) = true;
878 uint32_t reduce_cnt = 0;
879 for (uint32_t i = 0; i < check_reduce.size(); ++i)
880 if (check_reduce.at(i))
883 output_shape.rank(input_shape.rank() - reduce_cnt);
884 for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
885 if (check_reduce.at(i) == false)
886 output_shape.dim(j++) = input_shape.dim(i);
892 loco::NodeShape infer_mirror_pad(const luci::CircleMirrorPad *node)
894 // TODO support non-const case
895 auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
896 return use_paddings(node, paddings);
899 loco::NodeShape infer_one_hot(const luci::CircleOneHot *node)
901 const loco::DataType S32 = loco::DataType::S32;
902 auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
903 // Only support OneHot node's depth() is CircleConst with type S32
904 // TODO support depth with other types
905 auto depth = loco::must_cast<luci::CircleConst *>(node->depth());
906 LUCI_ASSERT(depth->dtype() == S32, "Only support int32 CircleConst");
907 if (depth->rank() != 0)
908 INTERNAL_EXN_V("Only support rank 0 CircleOneHot in Depth", oops::to_uint32(depth->rank()));
909 loco::TensorShape output_shape;
910 output_shape.rank(indices_shape.rank() + 1);
911 auto axis = node->axis();
913 axis += indices_shape.rank() + 1;
914 LUCI_ASSERT(0 <= axis, "Axis is out of range");
915 LUCI_ASSERT(static_cast<uint32_t>(axis) <= indices_shape.rank(), "Axis is out of range");
917 for (uint32_t i = 0; i < output_shape.rank(); i++)
919 if (i == static_cast<uint32_t>(axis))
921 output_shape.dim(i) = depth->at<S32>(0);
925 output_shape.dim(i) = indices_shape.dim(j++);
928 return loco::NodeShape{output_shape};
931 loco::NodeShape infer_pack(const luci::CirclePack *node)
933 LUCI_ASSERT(node->values_count() > 0, "Only support one or more inputs");
935 auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>();
936 // Make sure all inputs have the same shape.
937 for (uint32_t i = 1; i < node->values_count(); ++i)
939 auto in_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>();
940 LUCI_ASSERT(loco::NodeShape{first_shape} == loco::NodeShape{in_shape},
941 "All inputs must have the same shape");
944 // Checking shape capability for pack layer
945 // Input: tensors [D1, D2, ... Dn]
947 // Output: [D1, D2, ... , D_K-1, n, D_K+1, ... Dn]
948 auto axis = node->axis();
950 axis += first_shape.rank() + 1;
952 LUCI_ASSERT(0 <= axis, "Axis is out of range");
953 LUCI_ASSERT(static_cast<uint32_t>(axis) <= first_shape.rank(), "Axis is out of range");
955 loco::TensorShape output_shape;
956 output_shape.rank(first_shape.rank() + 1);
959 for (uint32_t i = 0; i < output_shape.rank(); ++i)
961 if (i == static_cast<uint32_t>(axis))
963 output_shape.dim(i) = node->values_count();
967 output_shape.dim(i) = first_shape.dim(j++);
971 return loco::NodeShape{output_shape};
974 loco::NodeShape infer_pad(const luci::CirclePad *node)
976 // TODO support non-const case
977 auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
978 return use_paddings(node, paddings);
981 loco::NodeShape infer_pad_v2(const luci::CirclePadV2 *node)
983 // TODO support non-const case
984 auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
987 auto node_shape = own_shape(node);
988 return loco::NodeShape{node_shape};
990 return use_paddings(node, paddings);
993 loco::NodeShape infer_p_relu(const luci::CirclePRelu *node)
995 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
996 auto alpha_shape = luci::shape_get(node->alpha()).as<loco::TensorShape>();
998 auto output_shape = broadcast_shape(input_shape, alpha_shape);
1000 return loco::NodeShape{output_shape};
1003 loco::NodeShape infer_range(const luci::CircleRange *node)
1005 loco::TensorShape output_shape;
1006 output_shape.rank(1);
1008 auto start_node = dynamic_cast<luci::CircleConst *>(node->start());
1009 auto limit_node = dynamic_cast<luci::CircleConst *>(node->limit());
1010 auto delta_node = dynamic_cast<luci::CircleConst *>(node->delta());
1012 if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
1014 return use_own(node);
1017 double start = 0, limit = 0, delta = 0;
1019 #define GET_RANGE_PARAM(DT) \
1020 start = start_node->scalar<DT>(); \
1021 limit = limit_node->scalar<DT>(); \
1022 delta = delta_node->scalar<DT>();
1024 switch (start_node->dtype())
1026 case loco::DataType::FLOAT32:
1027 GET_RANGE_PARAM(loco::DataType::FLOAT32)
1029 case loco::DataType::S32:
1030 GET_RANGE_PARAM(loco::DataType::S32)
1033 INTERNAL_EXN("Range data type not supported");
1036 #undef GET_RANGE_PARAM
1039 INTERNAL_EXN("Delta can not be zero");
1041 output_shape.dim(0) = ceil((limit - start) / delta);
1043 return loco::NodeShape{output_shape};
1046 loco::NodeShape infer_reshape(const luci::CircleReshape *node)
1050 const loco::DataType S32 = loco::DataType::S32;
1052 loco::TensorShape shape_by_input;
1054 LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
1056 // Only support node's shape() is CircleConst with S32
1057 // TODO support other node with other types
1058 auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
1059 if (const_shape_node != nullptr)
1061 LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");
1063 shape_by_input.rank(const_shape_node->size<S32>());
1065 for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
1067 shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
1072 // We use shape from the node itself
1073 shape_by_input = own_shape(node);
1077 loco::TensorShape shape_by_attr;
1079 shape_by_attr.rank(node->newShape()->rank());
1081 for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
1083 shape_by_attr.dim(axis) = node->newShape()->dim(axis);
1087 if (!(shape_by_input == shape_by_attr))
1089 INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
1090 INFO(l) << " shape_by_input : " << shape_by_input << std::endl;
1091 INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl;
1094 loco::TensorShape output_shape = shape_by_input;
1096 // One of the dimensions can have special value -1, meaning its actual value should be inferred.
1097 const auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>();
1098 uint32_t input_element_count = 1;
1099 uint32_t output_element_count = 1;
1100 uint32_t unknown_dim_index = UINT32_MAX;
1101 for (uint32_t i = 0; i < input_shape.rank(); ++i)
1102 input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
1103 for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
1105 const uint32_t dim_value = output_shape.dim(dim_index).value();
1106 if (static_cast<int>(dim_value) == -1)
1108 LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
1109 unknown_dim_index = dim_index;
1113 output_element_count *= dim_value;
1116 if (unknown_dim_index != UINT32_MAX)
1118 output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
1121 return loco::NodeShape{output_shape};
1124 template <class CIRCLENODE> loco::NodeShape infer_resize_type(const CIRCLENODE *node)
1126 auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
1128 if (input_shape.rank() != 4)
1129 INTERNAL_EXN("Expected input to have rank 4");
1131 auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
1133 if (const_node->dtype() != loco::DataType::S32)
1134 INTERNAL_EXN("Only S32 datatype is supported for size");
1136 if (const_node->rank() != 1)
1137 INTERNAL_EXN("Expected size tensor of rank 1");
1139 if (const_node->dim(0).value() != 2)
1140 INTERNAL_EXN("Expected size tensor with shape [2]");
1142 loco::TensorShape output_shape;
1143 output_shape.rank(4);
1144 output_shape.dim(0) = input_shape.dim(0);
1145 output_shape.dim(1) = const_node->template at<loco::DataType::S32>(0);
1146 output_shape.dim(2) = const_node->template at<loco::DataType::S32>(1);
1147 output_shape.dim(3) = input_shape.dim(3);
1149 return loco::NodeShape{output_shape};
1152 loco::NodeShape infer_scatter_nd(const luci::CircleScatterNd *node)
1154 loco::TensorShape output_shape;
1156 auto shape_node = loco::must_cast<luci::CircleConst *>(node->shape());
1158 const loco::DataType S32 = loco::DataType::S32;
1159 const loco::DataType S64 = loco::DataType::S64;
1161 std::vector<int64_t> vect_shape;
1163 if (shape_node->dtype() == S32)
1164 vect_shape = vector_from_constant<S32>(shape_node);
1165 else if (shape_node->dtype() == S64)
1166 vect_shape = vector_from_constant<S64>(shape_node);
1168 LUCI_ASSERT(false, "Only support int32/int64 for shape()");
1170 output_shape.rank(vect_shape.size());
1171 for (uint32_t i = 0; i < vect_shape.size(); ++i)
1172 output_shape.dim(i) = vect_shape[i];
1174 return loco::NodeShape{output_shape};
1177 loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node)
1179 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1180 auto segment_shape = luci::shape_get(node->segment_ids()).as<loco::TensorShape>();
1182 LUCI_ASSERT(segment_shape.rank() == 1, "segment_ids must be 1-D tensor");
1183 LUCI_ASSERT(segment_shape.dim(0).value() == input_shape.dim(0).value(),
1184 "segment_ids size must be equal to the size of data's first dimension");
1186 auto ids_shape_value = loco::must_cast<luci::CircleConst *>(node->segment_ids());
1188 std::vector<int64_t> vect_ids;
1190 if (ids_shape_value->dtype() == loco::DataType::S32)
1191 vect_ids = vector_from_constant<loco::DataType::S32>(ids_shape_value);
1193 LUCI_ASSERT(std::is_sorted(vect_ids.begin(), vect_ids.end()),
1194 "segment_ids values should be sorted")
1196 loco::TensorShape output_shape;
1198 output_shape.rank(input_shape.rank());
1200 for (uint32_t i = 1; i < input_shape.rank(); ++i)
1201 output_shape.dim(i) = input_shape.dim(i);
1203 output_shape.dim(0) = vect_ids.back() + 1;
1205 return loco::NodeShape{output_shape};
1208 loco::NodeShape infer_select(const luci::CircleSelect *node)
1210 auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>();
1211 assert(t_shape == luci::shape_get(node->e()).as<loco::TensorShape>());
1213 // condition shape validation
1214 auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>();
1215 if (c_shape.rank() != t_shape.rank())
1217 if (c_shape.rank() != 0 && c_shape.rank() != 1)
1218 INTERNAL_EXN_V("CircleSelect condition rank is not 0 nor 1: ", c_shape.rank());
1220 if (c_shape.rank() == 1)
1222 if (c_shape.dim(0).value() != t_shape.dim(0).value())
1223 INTERNAL_EXN("CircleSelect condition dim(0) should match with t.dim(0)");
1227 return loco::NodeShape{t_shape};
1230 loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node)
1232 auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>();
1233 auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>();
1234 auto e_shape = luci::shape_get(node->e()).as<loco::TensorShape>();
1236 // validate ability to broadcast shapes to each other
1237 auto b_shape = broadcast_shape(broadcast_shape(c_shape, t_shape), e_shape);
1238 return loco::NodeShape{b_shape};
1241 loco::NodeShape infer_shape(const luci::CircleShape *node)
1243 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1245 loco::TensorShape output_shape;
1247 output_shape.rank(1);
1248 output_shape.dim(0) = input_shape.rank();
1250 return loco::NodeShape{output_shape};
1253 loco::NodeShape infer_slice(const luci::CircleSlice *node)
1255 const loco::DataType S32 = loco::DataType::S32;
1256 const loco::DataType S64 = loco::DataType::S64;
1258 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1260 auto const_begin = loco::must_cast<luci::CircleConst *>(node->begin());
1261 auto const_size = loco::must_cast<luci::CircleConst *>(node->size());
1263 loco::TensorShape output_shape;
1264 std::vector<int64_t> vect_begin; // to hold both S32/S64, we use int64_t
1265 std::vector<int64_t> vect_size;
1267 if (const_begin->dtype() == S32)
1268 vect_begin = vector_from_constant<S32>(const_begin);
1269 else if (const_begin->dtype() == S64)
1270 vect_begin = vector_from_constant<S64>(const_begin);
1272 LUCI_ASSERT(false, "Only support int32/int64 for begin()");
1274 if (const_size->dtype() == S32)
1275 vect_size = vector_from_constant<S32>(const_size);
1276 else if (const_size->dtype() == S64)
1277 vect_size = vector_from_constant<S64>(const_size);
1279 LUCI_ASSERT(false, "Only support int32/int64 for size()");
1281 assert(input_shape.rank() == vect_begin.size());
1282 assert(input_shape.rank() == vect_size.size());
1284 output_shape.rank(vect_begin.size());
1285 for (uint32_t idx = 0; idx < vect_begin.size(); ++idx)
1287 auto size = vect_size.at(idx);
1290 size = input_shape.dim(idx).value() - vect_begin.at(idx);
1292 output_shape.dim(idx) = size;
1295 return loco::NodeShape{output_shape};
1298 loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
1300 const loco::DataType S32 = loco::DataType::S32;
1302 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1303 // Support only input rank is 3 and 4
1304 assert(input_shape.rank() == 3 || input_shape.rank() == 4);
1306 // Only support block_shape() with S32 type CircleConst for now
1307 auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
1308 LUCI_ASSERT(const_block_shape->dtype() == S32, "Only support int32 block_shape");
1310 // Only support paddings() with S32 type CircleConst for now
1311 auto const_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1312 LUCI_ASSERT(const_paddings->dtype() == S32, "Only support int32 paddings");
1314 auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>();
1315 auto const_paddings_shape = luci::shape_get(const_paddings).as<loco::TensorShape>();
1316 assert(const_block_shape_shape.rank() == 1);
1317 assert(const_paddings_shape.rank() == 2);
1319 int32_t input_spatial_dim = input_shape.rank() - 2;
1320 assert(const_block_shape_shape.dim(0) == input_spatial_dim);
1321 assert(const_paddings_shape.dim(0) == input_spatial_dim);
1322 assert(const_paddings_shape.dim(1) == 2);
1324 // Check all values of block_shape >= 1
1325 uint32_t ele_count = const_block_shape->size<S32>();
1326 for (uint32_t e = 0; e < ele_count; ++e)
1328 auto val = const_block_shape->at<S32>(e);
1331 INTERNAL_EXN_V("All values of block_shape >= 1: ", e);
1335 loco::TensorShape shape_output;
1337 shape_output.rank(input_shape.rank());
1339 int32_t output_batch_size = input_shape.dim(0).value();
1340 for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
1342 int dim_size = input_shape.dim(dim + 1).value();
1343 dim_size += const_paddings->at<S32>(dim * 2);
1344 dim_size += const_paddings->at<S32>(dim * 2 + 1);
1345 shape_output.dim(dim + 1) = dim_size / const_block_shape->at<S32>(dim);
1347 assert(dim_size % const_block_shape->at<S32>(dim) == 0);
1348 output_batch_size = output_batch_size * const_block_shape->at<S32>(dim);
1350 shape_output.dim(0) = output_batch_size;
1351 shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);
1353 return loco::NodeShape{shape_output};
1356 loco::NodeShape infer_space_to_depth(const luci::CircleSpaceToDepth *node)
1358 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1359 LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
1361 // Only data format NHWC is supported
1362 int32_t height = input_shape.dim(1).value();
1363 int32_t width = input_shape.dim(2).value();
1364 int32_t depth = input_shape.dim(3).value();
1366 int block_size = node->block_size();
1369 INTERNAL_EXN("Block size must be >= 2");
1371 if ((height % block_size) || (width % block_size))
1373 INTERNAL_EXN("The input tensor's height and width must be divisible by block_size");
1376 loco::TensorShape output_shape;
1377 output_shape.rank(4);
1379 output_shape.dim(0) = input_shape.dim(0).value();
1380 output_shape.dim(1) = height / block_size;
1381 output_shape.dim(2) = width / block_size;
1382 output_shape.dim(3) = block_size * block_size * depth;
1384 return loco::NodeShape{output_shape};
1387 loco::NodeShape infer_sparse_to_dense(const luci::CircleSparseToDense *node)
1389 loco::TensorShape shape;
1391 LUCI_ASSERT(node->output_shape(), "dims input should not be nullptr");
1393 auto output_shape_node = dynamic_cast<luci::CircleConst *>(node->output_shape());
1394 if (output_shape_node != nullptr)
1396 const auto output_shape_type = output_shape_node->dtype();
1398 if (output_shape_node->rank() != 1)
1399 INTERNAL_EXN_V("Only support rank 1 CircleConst",
1400 oops::to_uint32(output_shape_node->rank()));
1402 if (output_shape_type == loco::DataType::S32)
1404 shape.rank(output_shape_node->size<loco::DataType::S32>());
1406 for (uint32_t axis = 0; axis < shape.rank(); ++axis)
1408 shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis);
1411 else if (output_shape_type == loco::DataType::S64)
1413 shape.rank(output_shape_node->size<loco::DataType::S64>());
1415 for (uint32_t axis = 0; axis < shape.rank(); ++axis)
1417 shape.dim(axis) = output_shape_node->at<loco::DataType::S64>(axis);
1422 INTERNAL_EXN("Output shape of SparseToDense must be either int32 or int64");
1427 shape = own_shape(node);
1431 return loco::NodeShape{shape};
1434 loco::NodeShape infer_strided_slice(const luci::CircleStridedSlice *node)
1436 auto begin_node = dynamic_cast<luci::CircleConst *>(node->begin());
1437 auto end_node = dynamic_cast<luci::CircleConst *>(node->end());
1438 auto strides_node = dynamic_cast<luci::CircleConst *>(node->strides());
1440 if (begin_node == nullptr || end_node == nullptr || strides_node == nullptr)
1442 return use_own(node);
1445 loco::TensorShape shape = infer_output_shape(node);
1446 return loco::NodeShape{shape};
1449 loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
1451 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1453 // TODO input shape may be unknown before runtime
1454 std::vector<bool> do_squeeze(input_shape.rank(), false);
1455 uint32_t num_squeezed = 0;
1457 if (!node->squeeze_dims().empty())
1459 // SqueezeDims not empty, squeeze only dims specified
1460 for (int32_t raw_dim : node->squeeze_dims())
1462 int32_t dim = raw_dim < 0 ? raw_dim + input_shape.rank() : raw_dim;
1464 if (dim < 0 || static_cast<uint32_t>(dim) >= input_shape.rank() ||
1465 input_shape.dim(dim).value() != 1)
1467 INTERNAL_EXN("invalid dimention specified to Squeeze");
1470 if (!do_squeeze[dim])
1472 do_squeeze[dim] = true;
1477 // SqueezeDims empty, squeeze any dims with size == 1
1478 for (uint32_t dim = 0; dim < input_shape.rank(); ++dim)
1480 if (input_shape.dim(dim) == 1)
1482 do_squeeze[dim] = true;
1488 loco::TensorShape output_shape;
1489 output_shape.rank(input_shape.rank() - num_squeezed);
1491 for (uint32_t in_dim = 0, out_dim = 0; in_dim < input_shape.rank(); ++in_dim)
1493 if (!do_squeeze[in_dim])
1495 output_shape.dim(out_dim++) = input_shape.dim(in_dim);
1499 return loco::NodeShape{output_shape};
1502 loco::NodeShape infer_svdf(const luci::CircleSVDF *node)
1504 const auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1505 const auto weight_feature_shape = luci::shape_get(node->weight_feature()).as<loco::TensorShape>();
1507 assert(ifm_shape.rank() == 2);
1508 assert(weight_feature_shape.rank() == 2);
1510 assert(ifm_shape.dim(1) == weight_feature_shape.dim(1));
1511 assert(weight_feature_shape.dim(0).known());
1513 const auto rank = node->svdf_rank();
1514 const auto num_filters = weight_feature_shape.dim(0).value();
1515 assert(num_filters % rank == 0);
1516 const auto num_units = num_filters / rank;
1518 loco::TensorShape ofm_shape;
1520 ofm_shape.dim(0) = ifm_shape.dim(0);
1521 ofm_shape.dim(1) = num_units;
1523 return loco::NodeShape{ofm_shape};
1526 loco::NodeShape infer_tile(const luci::CircleTile *node)
1528 const loco::DataType S32 = loco::DataType::S32;
1530 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1531 auto multiples = loco::must_cast<luci::CircleConst *>(node->multiples());
1533 // TODO support non-const case
1534 // TODO support S64 type
1535 LUCI_ASSERT(multiples->dtype() == S32, "Only support int32 multiples");
1536 LUCI_ASSERT(multiples->rank() == 1, "multiples should be rank 1")
1538 uint32_t n = multiples->dim(0).value();
1540 LUCI_ASSERT(n == input_shape.rank(), "length of multiples should be the same with input rank");
1542 loco::TensorShape output_shape;
1544 output_shape.rank(input_shape.rank());
1545 for (uint32_t ni = 0; ni < n; ++ni)
1547 int32_t multiple = multiples->at<S32>(ni);
1548 output_shape.dim(ni) = input_shape.dim(ni).value() * static_cast<uint32_t>(multiple);
1551 return loco::NodeShape{output_shape};
1554 loco::NodeShape infer_transpose(const luci::CircleTranspose *node)
1556 auto input_shape = luci::shape_get(node->a()).as<loco::TensorShape>();
1558 auto perm_node = loco::must_cast<luci::CircleConst *>(node->perm());
1560 loco::TensorShape output_shape;
1561 output_shape.rank(input_shape.rank());
1563 assert(perm_node->dtype() == loco::DataType::S32);
1564 assert(input_shape.rank() == perm_node->template size<loco::DataType::S32>());
1566 for (uint32_t out_axis = 0; out_axis < output_shape.rank(); out_axis++)
1568 auto in_axis = perm_node->template at<loco::DataType::S32>(out_axis);
1569 output_shape.dim(out_axis) = input_shape.dim(in_axis);
1572 return output_shape;
1575 loco::NodeShape infer_transpose_conv(const luci::CircleTransposeConv *node)
1577 // TransposeConv's output shape is written in its 'inputSizes' argument
1578 auto input_sizes_const = loco::must_cast<luci::CircleConst *>(node->inputSizes());
1579 // TODO support non-const type
1580 LUCI_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
1581 LUCI_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
1582 "Only support rank 1 with 4 entries")
1584 loco::TensorShape shape;
1587 for (uint32_t axis = 0; axis < 4; ++axis)
1588 shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis);
1590 return loco::NodeShape{shape};
1593 loco::NodeShape infer_unpack(const luci::CircleUnpack *node)
1595 // CircleUnpack provides list(array) of Tensors which has one less dimension of the input
1596 // We'll set shape of CircleUnpack to shape of actual outputs
1597 // TODO fix this if any problem rises
1598 auto value_shape = luci::shape_get(node->value()).as<loco::TensorShape>();
1600 auto axis = node->axis();
1601 auto num = node->num();
1602 auto rank = static_cast<int32_t>(value_shape.rank());
1607 return use_own(node);
1610 LUCI_ASSERT(-rank <= axis && axis < rank, "Axis is out of range");
1615 LUCI_ASSERT(num == static_cast<int32_t>(value_shape.dim(axis).value()),
1616 "num, axis maybe incorrect");
1618 loco::TensorShape output_shape;
1619 output_shape.rank(rank - 1);
1621 for (int32_t i = 0, o = 0; i < rank; ++i)
1624 output_shape.dim(o++) = value_shape.dim(i);
1627 return loco::NodeShape{output_shape};
1630 loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectionalSequenceLSTM *node)
1632 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1633 auto recurrent_to_output_weights =
1634 luci::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>();
1635 auto rank = input_shape.rank();
1636 loco::TensorShape output_shape;
1637 output_shape.rank(rank);
1638 for (uint32_t i = 0; i < rank - 1; i++)
1640 output_shape.dim(i) = input_shape.dim(i);
1642 output_shape.dim(rank - 1) = recurrent_to_output_weights.dim(1);
1643 return loco::NodeShape{output_shape};
1646 loco::NodeShape infer_unique(const luci::CircleUnique *node)
1648 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1650 assert(input_shape.rank() == 1);
1652 loco::TensorShape shape_output;
1653 shape_output = own_shape(node);
1655 return loco::NodeShape{shape_output};
1659 loco::NodeShape infer_bcq_fully_connected(const luci::CircleBCQFullyConnected *node)
1661 loco::TensorShape out_shape;
1663 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1664 auto weights_clusters = loco::must_cast<luci::CircleConst *>(node->weights_clusters());
1666 LUCI_ASSERT(input_shape.rank() == 2, "Input rank of BCQFullyConnected should be 2");
1668 int32_t qbits_sum = 0;
1669 for (uint32_t i = 0; i < weights_clusters->dim(0).value(); ++i)
1671 qbits_sum += weights_clusters->at<loco::DataType::S32>(i * 2 + 1);
1675 out_shape.dim(0) = qbits_sum;
1676 out_shape.dim(1) = input_shape.dim(1);
1678 return loco::NodeShape{out_shape};
1681 loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node)
1683 loco::TensorShape input_shape;
1684 loco::TensorShape output_shape;
1686 const auto input_binary_shape = luci::shape_get(node->input_binary()).as<loco::TensorShape>();
1687 const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
1688 auto axis = node->axis();
1690 auto input_clusters = loco::must_cast<luci::CircleConst *>(node->input_clusters());
1692 for (uint32_t i = 0; i < input_clusters->dim(0).value(); ++i)
1694 qbits_sum += input_clusters->at<loco::DataType::S32>(i * 2 + 1);
1697 input_shape.rank(2);
1698 input_shape.dim(0) = qbits_sum;
1699 input_shape.dim(1) = input_binary_shape.dim(1).value() * 32;
1701 output_shape.rank(input_shape.rank() - 1 + indices_shape.rank());
1702 int32_t outdim_index = 0;
1703 for (int32_t i = 0; i < axis; ++i)
1704 output_shape.dim(outdim_index++) = input_shape.dim(i);
1705 for (uint32_t i = 0; i < indices_shape.rank(); ++i)
1706 output_shape.dim(outdim_index++) = indices_shape.dim(i);
1707 for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
1708 output_shape.dim(outdim_index++) = input_shape.dim(i);
1710 return loco::NodeShape{output_shape};
1714 loco::NodeShape infer_input(const luci::CircleInput *node)
1716 loco::TensorShape shape;
1718 shape.rank(node->rank());
1719 for (uint32_t axis = 0; axis < node->rank(); axis++)
1720 shape.dim(axis) = node->dim(axis);
1722 return loco::NodeShape{shape};
1725 loco::NodeShape infer_output(const luci::CircleOutput *node)
1727 auto graph_outputs = node->graph()->outputs();
1728 auto graph_output = graph_outputs->at(node->index());
1729 auto output_shape = graph_output->shape();
1731 return loco::NodeShape{*output_shape};
1734 loco::NodeShape infer_non_max_suppression_v4_out(const luci::CircleNonMaxSuppressionV4Out *node)
1736 const loco::DataType S32 = loco::DataType::S32;
1738 auto nmsv4 = dynamic_cast<const luci::CircleNonMaxSuppressionV4 *>(node->input());
1739 if (nmsv4 == nullptr)
1740 INTERNAL_EXN("CircleNonMaxSuppressionV4 IR is not configured correctly");
1742 auto index = node->index();
1744 return loco::TensorShape({0});
1748 auto unknown = loco::TensorShape{loco::Dimension()};
1749 auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv4->max_output_size());
1750 if (max_output_size == nullptr)
1751 return unknown; // we need CircleConst for max output size
1753 LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
1755 if (max_output_size->size<S32>() < 1)
1758 auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
1759 return loco::TensorShape{max_output_size_value};
1762 loco::NodeShape infer_non_max_suppression_v5_out(const luci::CircleNonMaxSuppressionV5Out *node)
1764 const loco::DataType S32 = loco::DataType::S32;
1766 auto nmsv5 = dynamic_cast<const luci::CircleNonMaxSuppressionV5 *>(node->input());
1767 if (nmsv5 == nullptr)
1768 INTERNAL_EXN("CircleNonMaxSuppressionV5 IR is not configured correctly");
1770 auto index = node->index();
1772 return loco::TensorShape({0});
1774 assert(index == 0 || index == 1);
1776 auto unknown = loco::TensorShape{loco::Dimension()};
1777 auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv5->max_output_size());
1778 if (max_output_size == nullptr)
1779 return unknown; // we need CircleConst for max output size
1781 LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
1783 if (max_output_size->size<S32>() < 1)
1786 auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
1787 return loco::TensorShape{max_output_size_value};
1790 loco::NodeShape infer_split_out(const luci::CircleSplitOut *node)
1792 const loco::DataType S32 = loco::DataType::S32;
1794 auto split = dynamic_cast<const luci::CircleSplit *>(node->input());
1795 if (split == nullptr)
1796 INTERNAL_EXN("CircleSplit IR is not configured correctly");
1798 loco::NodeShape unknown;
1800 auto split_shape = luci::shape_get(split).as<loco::TensorShape>();
1802 auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
1803 if (split_dim == nullptr)
1804 return unknown; // we need CircleConst for split_dim
1805 LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
1807 assert(split_dim->size<S32>() == 1);
1808 auto split_dim_axis = split_dim->at<S32>(0);
1809 if (split_dim_axis < 0)
1810 split_dim_axis += split_shape.rank();
1812 auto split_dim_value = split_shape.dim(split_dim_axis).value();
1813 assert(split_dim_value % split->num_split() == 0);
1814 const int split_depth = split_dim_value / split->num_split();
1816 loco::TensorShape output_shape = split_shape;
1818 // All shapes are equally same
1819 output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
1821 return loco::NodeShape{output_shape};
1824 loco::NodeShape infer_split_v_out(const luci::CircleSplitVOut *node)
1826 const loco::DataType S32 = loco::DataType::S32;
1828 auto split = dynamic_cast<const luci::CircleSplitV *>(node->input());
1829 if (split == nullptr)
1830 INTERNAL_EXN("CircleSplit IR is not configured correctly");
1832 loco::NodeShape unknown;
1834 auto split_shape = luci::shape_get(split).as<loco::TensorShape>();
1836 auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits());
1837 if (size_splits == nullptr)
1838 return unknown; // we need CircleConst for size_splits
1839 LUCI_ASSERT(size_splits->dtype() == S32, "Only support int32 for size_splits");
1841 auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
1842 if (split_dim == nullptr)
1843 return unknown; // we need CircleConst for split_dim
1844 LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
1847 assert(split_dim->size<S32>() == 1);
1848 auto split_dim_axis = split_dim->at<S32>(0);
1849 if (split_dim_axis < 0)
1850 split_dim_axis += split_shape.rank();
1852 // interpret size_splits values
1853 int32_t size_splits_count = static_cast<int32_t>(size_splits->size<S32>());
1854 assert(size_splits_count == split->num_split());
1856 int64_t minus_one_count = 0, size_splits_sum = 0;
1857 for (int32_t idx = 0; idx < size_splits_count; ++idx)
1859 auto size = size_splits->at<S32>(idx);
1864 size_splits_sum += size;
1866 if (minus_one_count > 1)
1867 INTERNAL_EXN("CircleSplitV size_splits has more than two -1 values");
1869 // calcuate this SplitVOut shape
1870 auto input_size = split_shape.dim(split_dim_axis).value();
1871 assert(size_splits_sum <= input_size);
1873 auto index_this = node->index();
1874 assert(0 <= index_this && index_this < split->num_split());
1875 auto split_depth = size_splits->at<S32>(index_this);
1876 if (split_depth == -1)
1877 split_depth = input_size - size_splits_sum;
1879 loco::TensorShape output_shape = split_shape;
1881 output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
1883 return loco::NodeShape{output_shape};
1886 loco::NodeShape infer_top_k_v2_out(const luci::CircleTopKV2Out *node)
1888 const loco::DataType S32 = loco::DataType::S32;
1890 auto topkv2 = dynamic_cast<const luci::CircleTopKV2 *>(node->input());
1891 if (topkv2 == nullptr)
1892 INTERNAL_EXN("CircleSplit IR is not configured correctly");
1894 // shape of topkv2 is same as topkv2->input()
1895 auto input_shape = luci::shape_get(topkv2).as<loco::TensorShape>();
1897 auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k());
1898 LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32");
1899 assert(node_k->size<S32>() == 1);
1901 loco::TensorShape output_shape;
1903 output_shape.rank(input_shape.rank());
1904 for (uint32_t idx = 0; idx < input_shape.rank() - 1; ++idx)
1906 output_shape.dim(idx) = input_shape.dim(idx);
1908 output_shape.dim(input_shape.rank() - 1) = node_k->at<S32>(0);
1910 return loco::NodeShape{output_shape};
1913 loco::NodeShape infer_unique_out(const luci::CircleUniqueOut *node)
1915 if (node->index() == 0)
1917 auto unique_shape = own_shape(node);
1918 return loco::NodeShape{unique_shape};
1920 assert(node->index() == 1);
1921 auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
1922 auto unique_shape = luci::shape_get(unique->input()).as<loco::TensorShape>();
1924 assert(unique_shape.rank() == 1);
1926 loco::TensorShape shape_output;
1927 shape_output.rank(1);
1928 shape_output.dim(0) = unique_shape.dim(0);
1929 return loco::NodeShape{shape_output};
1932 loco::NodeShape infer_unpack_out(const luci::CircleUnpackOut *node)
1934 auto unpack = dynamic_cast<const luci::CircleUnpack *>(node->input());
1935 if (unpack == nullptr)
1937 INTERNAL_EXN("CircleUnpack IR is not configured correctly");
1940 auto unpack_shape = luci::shape_get(unpack).as<loco::TensorShape>();
1942 return loco::NodeShape{unpack_shape};
1945 loco::NodeShape infer_while_out(const luci::CircleWhileOut *node)
1948 * @note WHILE operator's shape is the same with the "cond"
1951 auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
1952 if (circle_while == nullptr)
1954 INTERNAL_EXN("CircleWhile IR is not configured correctly");
1957 auto index = node->index();
1958 auto cond_graph = circle_while->cond_graph();
1959 assert(cond_graph != nullptr);
1961 // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
1962 // loco::input_nodes
1963 auto cond_inputs = loco::input_nodes(cond_graph);
1964 auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
1966 auto cond_graph_inputs = cond_graph->inputs();
1967 auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
1969 auto cond_graph_input_shape = *cond_graph_input->shape();
1970 auto this_shape = own_shape(node);
1972 if (!(this_shape == cond_graph_input_shape))
1975 WARN(l) << "Warning: CircleWhileOut '" << node->name() << "' shape mispatch " << this_shape
1976 << " vs " << cond_graph_input_shape;
1979 return loco::NodeShape{this_shape};
1983 * @brief Class to infer the shape of CircleNode
1985 * @note All CircleNode's inputs and outputs are always loco::Domain::Tensor
1987 class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeShape>
// Each visit() override below returns the inferred output shape for one Circle
// operator. Recurring helper patterns (all defined earlier in this file):
//  - use_x / use_input / use_inputs / use_logits : output shape equals the
//    shape of the referenced input
//  - use_own : keep the shape already stored on the node itself
//  - broadcast_xy : combine the shapes of inputs x and y; presumably NumPy-style
//    broadcasting (see the broadcast helper declared above) -- confirm in helper
//  - infer_<op> : operator-specific inference implemented elsewhere in this file
1990 loco::NodeShape visit(const luci::CircleAbs *node) final { return use_x(node); }
1992 loco::NodeShape visit(const luci::CircleAdd *node) final { return broadcast_xy(node); }
1994 loco::NodeShape visit(const luci::CircleAddN *node) final { return infer_add_n(node); }
1996 loco::NodeShape visit(const luci::CircleArgMax *node) final { return infer_arg_maxmin(node); }
1998 loco::NodeShape visit(const luci::CircleArgMin *node) final { return infer_arg_maxmin(node); }
2000 loco::NodeShape visit(const luci::CircleAveragePool2D *node) final
2002 return infer_pool_2d_shape(node);
2005 loco::NodeShape visit(const luci::CircleBatchMatMul *node) final
2007 auto x_shape = luci::shape_get(node->x()).as<loco::TensorShape>();
2008 auto y_shape = luci::shape_get(node->y()).as<loco::TensorShape>();
2010 return infer_batchmatmul_shape(x_shape, y_shape, node->adj_x(), node->adj_y());
2013 loco::NodeShape visit(const luci::CircleBatchToSpaceND *node) final
2015 return infer_batch_to_space_nd(node);
2018 loco::NodeShape visit(const luci::CircleCast *node) final { return use_x(node); }
2020 loco::NodeShape visit(const luci::CircleCeil *node) final { return use_x(node); }
2022 loco::NodeShape visit(const luci::CircleConcatenation *node) final
2024 return infer_concatenation(node);
2027 loco::NodeShape visit(const luci::CircleConst *node) final { return use_own(node); }
2029 loco::NodeShape visit(const luci::CircleConv2D *node) final { return infer_conv2d(node); }
2031 loco::NodeShape visit(const luci::CircleCos *node) final { return use_x(node); }
2033 loco::NodeShape visit(const luci::CircleCustom *node) final { return use_own(node); }
2035 loco::NodeShape visit(const luci::CircleDensify *node) final { return use_input(node); }
2037 loco::NodeShape visit(const luci::CircleDepthToSpace *node) final
2039 return infer_depth_to_space(node);
2042 loco::NodeShape visit(const luci::CircleDepthwiseConv2D *node) final
2044 return infer_depthwise_conv2d(node);
2047 loco::NodeShape visit(const luci::CircleDequantize *node) final
// Dequantize is elementwise: output shape is simply the input's shape
2049 const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2050 return loco::NodeShape{input_shape};
2053 loco::NodeShape visit(const luci::CircleDiv *node) final { return broadcast_xy(node); }
2055 loco::NodeShape visit(const luci::CircleElu *node) final
2057 auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2059 return loco::NodeShape{input_shape};
2062 loco::NodeShape visit(const luci::CircleEqual *node) final { return broadcast_xy(node); }
2064 loco::NodeShape visit(const luci::CircleExp *node) final { return use_x(node); }
2066 loco::NodeShape visit(const luci::CircleExpandDims *node) final
2068 return infer_expand_dims(node);
2071 loco::NodeShape visit(const luci::CircleFakeQuant *node) final { return use_inputs(node); }
2073 loco::NodeShape visit(const luci::CircleFill *node) final { return infer_fill(node); }
2075 loco::NodeShape visit(const luci::CircleFloor *node) final { return use_x(node); }
2077 loco::NodeShape visit(const luci::CircleFloorDiv *node) final { return broadcast_xy(node); }
2079 loco::NodeShape visit(const luci::CircleFloorMod *node) final { return broadcast_xy(node); }
2081 loco::NodeShape visit(const luci::CircleFullyConnected *node) final
2083 return infer_fully_connected(node);
2086 loco::NodeShape visit(const luci::CircleGather *node) final { return infer_gather(node); }
2088 loco::NodeShape visit(const luci::CircleGatherNd *node) final { return infer_gather_nd(node); }
2090 loco::NodeShape visit(const luci::CircleGreater *node) final { return broadcast_xy(node); }
2092 loco::NodeShape visit(const luci::CircleGreaterEqual *node) final { return broadcast_xy(node); }
2094 loco::NodeShape visit(const luci::CircleIf *node) final
2096 // Shape of CircleIf is not used. Just use input 0
2097 assert(node->input_count() > 0);
2098 const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>();
2099 return loco::NodeShape{input_shape};
2102 loco::NodeShape visit(const luci::CircleL2Normalize *node) final { return use_x(node); }
2104 loco::NodeShape visit(const luci::CircleL2Pool2D *node) final
2106 return infer_pool_2d_shape(node);
2109 loco::NodeShape visit(const luci::CircleLeakyRelu *node) final
2111 const auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2112 return loco::NodeShape{input_shape};
2115 loco::NodeShape visit(const luci::CircleLess *node) final { return broadcast_xy(node); }
2117 loco::NodeShape visit(const luci::CircleLessEqual *node) final { return broadcast_xy(node); }
2119 loco::NodeShape visit(const luci::CircleLocalResponseNormalization *node) final
2121 const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2122 return loco::NodeShape{input_shape};
2125 loco::NodeShape visit(const luci::CircleLog *node) final { return use_x(node); }
2127 loco::NodeShape visit(const luci::CircleLogicalAnd *node) final { return use_x(node); }
2129 loco::NodeShape visit(const luci::CircleLogicalNot *node) final { return use_x(node); }
2131 loco::NodeShape visit(const luci::CircleLogicalOr *node) final { return use_x(node); }
2133 loco::NodeShape visit(const luci::CircleLogistic *node) final { return use_x(node); }
2135 loco::NodeShape visit(const luci::CircleLogSoftmax *node) final { return use_logits(node); }
2137 loco::NodeShape visit(const luci::CircleMatrixDiag *node) final
2139 return infer_matrix_diag(node);
2142 loco::NodeShape visit(const luci::CircleMatrixSetDiag *node) final
2144 return infer_matrix_set_diag(node);
2147 loco::NodeShape visit(const luci::CircleMaximum *node) final { return broadcast_xy(node); }
2149 loco::NodeShape visit(const luci::CircleMaxPool2D *node) final
2151 return infer_pool_2d_shape(node);
// Reduction ops: shape comes from infer_reducer(input, reduction_indices, keep_dims)
2154 loco::NodeShape visit(const luci::CircleMean *node) final
2156 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2157 return loco::NodeShape{output_shape};
2160 loco::NodeShape visit(const luci::CircleMinimum *node) final { return broadcast_xy(node); }
2162 loco::NodeShape visit(const luci::CircleMirrorPad *node) final { return infer_mirror_pad(node); }
2164 loco::NodeShape visit(const luci::CircleMul *node) final { return broadcast_xy(node); }
2166 loco::NodeShape visit(const luci::CircleNeg *node) final { return use_x(node); }
// NOTE(review): NMS nodes use the 'boxes' input shape here; per-output shapes are
// handled by the corresponding *Out visitors further below
2168 loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final
2170 const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>();
2171 return loco::NodeShape{boxes_shape};
2174 loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5 *node) final
2176 const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>();
2177 return loco::NodeShape{boxes_shape};
2180 loco::NodeShape visit(const luci::CircleNotEqual *node) final { return broadcast_xy(node); }
2182 loco::NodeShape visit(const luci::CircleOneHot *node) final { return infer_one_hot(node); }
2184 loco::NodeShape visit(const luci::CirclePack *node) final { return infer_pack(node); }
2186 loco::NodeShape visit(const luci::CirclePad *node) final { return infer_pad(node); }
2188 loco::NodeShape visit(const luci::CirclePadV2 *node) final { return infer_pad_v2(node); }
2190 loco::NodeShape visit(const luci::CirclePow *node) final { return broadcast_xy(node); }
2192 loco::NodeShape visit(const luci::CirclePRelu *node) final { return infer_p_relu(node); }
2194 loco::NodeShape visit(const luci::CircleQuantize *node) final
2196 const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2197 return loco::NodeShape{input_shape};
2200 loco::NodeShape visit(const luci::CircleRange *node) final { return infer_range(node); }
// Rank produces a scalar, so the output shape has rank 0
2202 loco::NodeShape visit(const luci::CircleRank *) final
2204 loco::TensorShape shape_output;
2205 shape_output.rank(0);
2207 return loco::NodeShape{shape_output};
2210 loco::NodeShape visit(const luci::CircleReduceAny *node) final
2212 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2213 return loco::NodeShape{output_shape};
2216 loco::NodeShape visit(const luci::CircleReduceMax *node) final
2218 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2219 return loco::NodeShape{output_shape};
2222 loco::NodeShape visit(const luci::CircleReduceMin *node) final
2224 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2225 return loco::NodeShape{output_shape};
2228 loco::NodeShape visit(const luci::CircleReduceProd *node) final
2230 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2231 return loco::NodeShape{output_shape};
2234 loco::NodeShape visit(const luci::CircleRelu *node) final
2236 auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2238 return loco::NodeShape{input_shape};
2241 loco::NodeShape visit(const luci::CircleRelu6 *node) final
2243 auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2245 return loco::NodeShape{input_shape};
2248 loco::NodeShape visit(const luci::CircleReluN1To1 *node) final
2250 auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2252 return loco::NodeShape{input_shape};
2256 * @note CircleReshape has new shape info in two places: 2nd input and attribute.
2257 * This shape inference uses shape from input 'shape' node when it's constant.
2258 * If not, shape will be from node itself. shape from attribute is not used.
2260 * TODO Change this policy when not appropriate
2262 loco::NodeShape visit(const luci::CircleReshape *node) final { return infer_reshape(node); }
2264 loco::NodeShape visit(const luci::CircleResizeBilinear *node) final
2266 return infer_resize_type(node);
2269 loco::NodeShape visit(const luci::CircleResizeNearestNeighbor *node) final
2271 return infer_resize_type(node);
2274 loco::NodeShape visit(const luci::CircleReverseSequence *node) final
2276 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2278 return loco::NodeShape{input_shape};
2281 loco::NodeShape visit(const luci::CircleRound *node) final { return use_x(node); }
2283 loco::NodeShape visit(const luci::CircleReverseV2 *node) final
2285 auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>();
// ReverseV2 only permutes elements along axes, so output shape equals input;
// the axis input itself must be a 1-D tensor
2287 LUCI_ASSERT(luci::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1,
2288 "Tensor must be 1-D");
2290 return loco::NodeShape{input_shape};
2293 loco::NodeShape visit(const luci::CircleRsqrt *node) final { return use_x(node); }
2295 loco::NodeShape visit(const luci::CircleScatterNd *node) final { return infer_scatter_nd(node); }
2297 loco::NodeShape visit(const luci::CircleSegmentSum *node) final
2299 return infer_segment_sum(node);
2302 loco::NodeShape visit(const luci::CircleSelect *node) final { return infer_select(node); }
2304 loco::NodeShape visit(const luci::CircleSelectV2 *node) final { return infer_select_v2(node); }
2306 loco::NodeShape visit(const luci::CircleShape *node) final { return infer_shape(node); }
2308 loco::NodeShape visit(const luci::CircleSin *node) final { return use_x(node); }
2310 loco::NodeShape visit(const luci::CircleSlice *node) final { return infer_slice(node); }
2312 loco::NodeShape visit(const luci::CircleSoftmax *node) final { return use_logits(node); }
2314 loco::NodeShape visit(const luci::CircleSpaceToBatchND *node) final
2316 return infer_space_to_batch_nd(node);
2319 loco::NodeShape visit(const luci::CircleSpaceToDepth *node) final
2321 return infer_space_to_depth(node);
2324 loco::NodeShape visit(const luci::CircleSparseToDense *node) final
2326 return infer_sparse_to_dense(node);
2329 loco::NodeShape visit(const luci::CircleSplit *node) final
2331 // We'll set Split output as same as input so that SplitOut can handle it's own shape
2332 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2333 return loco::NodeShape{input_shape};
2336 loco::NodeShape visit(const luci::CircleSplitV *node) final
2338 // We'll set SplitV output as same as input so that SplitOut can handle it's own shape
2339 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2340 return loco::NodeShape{input_shape};
2343 loco::NodeShape visit(const luci::CircleSqrt *node) final { return use_x(node); }
2345 loco::NodeShape visit(const luci::CircleSquare *node) final { return use_x(node); }
2347 loco::NodeShape visit(const luci::CircleSquaredDifference *node) final
2349 return broadcast_xy(node);
2352 loco::NodeShape visit(const luci::CircleStridedSlice *node) final
2354 return infer_strided_slice(node);
2357 loco::NodeShape visit(const luci::CircleSqueeze *node) final { return infer_squeeze(node); }
2359 loco::NodeShape visit(const luci::CircleSub *node) final { return broadcast_xy(node); }
2361 loco::NodeShape visit(const luci::CircleSum *node) final
2363 auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2364 return loco::NodeShape{output_shape};
2367 loco::NodeShape visit(const luci::CircleSVDF *node) final { return infer_svdf(node); }
2369 loco::NodeShape visit(const luci::CircleTanh *node) final { return use_x(node); }
2371 loco::NodeShape visit(const luci::CircleTile *node) final { return infer_tile(node); }
2373 loco::NodeShape visit(const luci::CircleTopKV2 *node) final
2375 // set shape of this node as same as input
2376 const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2377 return loco::NodeShape{input_shape};
2380 loco::NodeShape visit(const luci::CircleTranspose *node) final { return infer_transpose(node); }
2382 loco::NodeShape visit(const luci::CircleTransposeConv *node) final
2384 return infer_transpose_conv(node);
2387 loco::NodeShape visit(const luci::CircleUnpack *node) final { return infer_unpack(node); }
2389 loco::NodeShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
2391 return infer_unidirectionalsequencelstm(node);
2394 loco::NodeShape visit(const luci::CircleUnique *node) final { return infer_unique(node); }
2396 loco::NodeShape visit(const luci::CircleWhere *node) final { return use_own(node); }
2398 loco::NodeShape visit(const luci::CircleWhile *node) final
2400 // Shape of CircleWhile is not used. Just use input 0
2401 assert(node->arity() > 0);
2402 const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>();
2403 return loco::NodeShape{input_shape};
2406 loco::NodeShape visit(const luci::CircleZerosLike *node) final
2408 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2410 return loco::NodeShape{input_shape};
// Circle-extension operators (BCQ*, InstanceNorm)
2414 loco::NodeShape visit(const luci::CircleBCQFullyConnected *node) final
2416 return infer_bcq_fully_connected(node);
2419 loco::NodeShape visit(const luci::CircleBCQGather *node) final { return infer_bcq_gather(node); }
2421 loco::NodeShape visit(const luci::CircleInstanceNorm *node) final
2423 auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2425 return loco::NodeShape{input_shape};
// Virtual nodes: graph inputs/outputs and per-output helper nodes of
// multi-output operators (Split, TopKV2, Unique, Unpack, While, NMS, ...)
2429 loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); }
2431 loco::NodeShape visit(const luci::CircleOutput *node) final { return infer_output(node); }
2433 loco::NodeShape visit(const luci::CircleOutputDummy *node) final { return use_own(node); }
2435 loco::NodeShape visit(const luci::CircleOutputExclude *node) final { return use_own(node); }
2437 loco::NodeShape visit(const luci::CircleCustomOut *node) final { return use_own(node); }
2439 loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final
2441 return infer_non_max_suppression_v4_out(node);
2444 loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final
2446 return infer_non_max_suppression_v5_out(node);
2449 loco::NodeShape visit(const luci::CircleSplitOut *node) final { return infer_split_out(node); }
2451 loco::NodeShape visit(const luci::CircleSplitVOut *node) final { return infer_split_v_out(node); }
2453 loco::NodeShape visit(const luci::CircleTopKV2Out *node) final
2455 return infer_top_k_v2_out(node);
2458 loco::NodeShape visit(const luci::CircleUniqueOut *node) final { return infer_unique_out(node); }
2460 loco::NodeShape visit(const luci::CircleUnpackOut *node) final { return infer_unpack_out(node); }
2462 loco::NodeShape visit(const luci::CircleVariable *node) final { return use_own(node); }
2464 loco::NodeShape visit(const luci::CircleWhileOut *node) final { return infer_while_out(node); }
2472 bool CircleShapeInferenceRule::recognize(const loco::Dialect *d) const
2474 return CircleDialect::get() == d;
2477 bool CircleShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
2481 assert(node->dialect() == CircleDialect::get());
2483 ShapeInferenceAlgorithm alg;
2484 auto circle_node = loco::must_cast<const CircleNode *>(node);
2486 bool is_shape_undefined = (circle_node->shape_status() == ShapeStatus::UNDEFINED);
2487 bool is_shape_none = (circle_node->shape_status() == ShapeStatus::NOSHAPE);
2488 bool is_scalar = (circle_node->rank() == 0);
2490 if (is_shape_undefined)
2491 shape = circle_node->accept(&alg);
2494 if (is_shape_none || is_scalar)
2495 shape = own_shape(circle_node);
2497 shape = circle_node->accept(&alg);
2500 VERBOSE(l, 1) << "[luci] shape: " << circle_node->name();
2501 VERBOSE(l, 1) << " own_shape: " << own_shape(circle_node)
2502 << " -> infer: " << shape.as<loco::TensorShape>();