2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "util/Utils.h"
19 #include "ir/InternalType.h"
21 #include "util/ShapeInference.h"
22 #include "util/logging.h"
30 namespace shape_inference
// Integer ceiling division: smallest integer >= dividend / divisor.
// Restricted via SFINAE to integral operands; result uses the common type.
// Both operands must be strictly positive (see assert below).
template <typename T, typename U>
typename std::enable_if<std::is_integral<T>::value && std::is_integral<U>::value,
                        typename std::common_type<T, U>::type>::type
ceil_div(T dividend, U divisor)
{
  assert(dividend > 0 && divisor > 0 && "this implementations is for positive numbers only");
  return (dividend + divisor - 1) / divisor;
}
49 // Calculate the result of broadcast of two shapes
50 ir::Shape broadcastShapes(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
53 auto max_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
55 for (int idx = 0; idx < max_rank; ++idx)
57 // Go over operands dimensions from right to left
58 int lhs_idx = lhs_shape.rank() - idx - 1;
59 int rhs_idx = rhs_shape.rank() - idx - 1;
61 int32_t lhs_dim = lhs_idx >= 0 ? lhs_shape.dim(lhs_idx) : 1;
62 int32_t rhs_dim = rhs_idx >= 0 ? rhs_shape.dim(rhs_idx) : 1;
64 if (lhs_dim != 1 && rhs_dim != 1 && lhs_dim != rhs_dim)
65 throw std::runtime_error("Incompatible shapes for broadcast");
67 out_shape.prepend(std::max(lhs_dim, rhs_dim));
79 // Calculate output height and width of convolution-like operation
80 std::pair<int, int> calcConvLikeHeightAndWidth(const int in_h, const int in_w, const int ker_h,
81 const int ker_w, const ir::Padding pad,
82 const ir::Stride stride,
83 const ir::Dilation dilation = {1, 1})
85 int32_t out_h = 0, out_w = 0;
86 int32_t effective_filter_w_size = (ker_w - 1) * dilation.width_factor + 1;
87 int32_t effective_filter_h_size = (ker_h - 1) * dilation.height_factor + 1;
90 case ir::PaddingType::SAME:
91 out_h = ceil_div(in_h, stride.vertical);
92 out_w = ceil_div(in_w, stride.horizontal);
94 case ir::PaddingType::VALID:
95 out_h = ceil_div(in_h - effective_filter_h_size + 1, stride.vertical);
96 out_w = ceil_div(in_w - effective_filter_w_size + 1, stride.horizontal);
98 case ir::PaddingType::EXPLICIT:
100 (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
102 (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal +
109 return {out_h, out_w};
112 ir::Shape inferEltwiseShape(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
114 return broadcastShapes(lhs_shape, rhs_shape);
117 ir::Shape inferArgMaxShape(const ir::Shape &input_shape, int axis, int rank)
120 for (int idx = 0; idx < rank; ++idx)
124 int32_t input_dim = input_shape.dim(idx);
125 out_shape.append(input_dim);
132 ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int> &axes,
135 int num_axis = axes.size();
136 int input_num_dims = input_shape.rank();
137 if (input_num_dims == 0)
139 ir::Shape out_shape(0);
145 for (int idx = 0; idx < input_num_dims; ++idx)
147 bool is_axis = false;
148 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
150 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
162 out_shape.append(input_shape.dim(idx));
169 // Calculates size of reducing axis.
170 int num_reduce_axis = num_axis;
171 for (int i = 0; i < num_axis; ++i)
173 int current = axes[i];
176 current += input_num_dims;
178 assert(0 <= current && current < input_num_dims);
179 for (int j = 0; j < i; ++j)
181 int previous = axes[j];
184 previous += input_num_dims;
186 if (current == previous)
193 // Determines output dimensions.
195 int num_skip_axis = 0;
196 for (int idx = 0; idx < input_num_dims; ++idx)
198 bool is_axis = false;
199 for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
201 if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
210 out_shape.append(input_shape.dim(idx));
217 ir::Shape inferBatchMatMulShape(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape,
218 const ir::operation::BatchMatMul::Param ¶m)
220 bool adj_x = param.adj_x;
221 bool adj_y = param.adj_y;
222 ir::Shape output_shape;
224 int output_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
226 // Extend lhs and rhs shape
227 ir::Shape extended_lhs_shape(lhs_shape);
228 ir::Shape extended_rhs_shape(rhs_shape);
229 extended_lhs_shape.extendRank(output_rank);
230 extended_rhs_shape.extendRank(output_rank);
232 for (int i = 0; i < output_rank - 2; i++)
234 const int lhs_dim = extended_lhs_shape.dim(i);
235 const int rhs_dim = extended_rhs_shape.dim(i);
236 int broadcast_dim = lhs_dim;
237 if (lhs_dim != rhs_dim)
241 broadcast_dim = rhs_dim;
243 else if (rhs_dim != 1)
245 throw std::runtime_error{"BatchMatMul shape inference: invalid brodcasting input shape"};
249 output_shape.append(broadcast_dim);
252 // Fill in the matmul dimensions.
253 int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
254 int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
256 output_shape.append(extended_lhs_shape.dim(lhs_rows_index));
257 output_shape.append(extended_rhs_shape.dim(rhs_cols_index));
262 ir::Shape inferBroadcastToShape(const ir::Shape wshape, const int32_t *shape_buffer)
264 const int num_elements = wshape.num_elements();
266 assert(num_elements != 0);
267 assert(shape_buffer);
269 ir::Shape new_shape(num_elements);
271 for (int i = 0; i < num_elements; ++i)
273 assert(shape_buffer[i] != 0); // It shouldn't be 0.
274 new_shape.dim(i) = shape_buffer[i];
280 ir::Shape inferConcatShape(const Shapes &in_shapes, const ir::operation::Concat::Param ¶m)
282 const int32_t concat_axis = param.axis >= 0 ? param.axis : in_shapes[0].rank() + param.axis;
283 const auto &first_in_shape = in_shapes[0];
285 // Check that all shapes are equal except for concat axis dimension
286 for (const auto &in_shape : in_shapes)
288 if (in_shape.rank() != first_in_shape.rank())
289 throw std::runtime_error("Rank in all input tensors should be same");
291 for (int64_t dim_idx = 0; dim_idx < in_shape.rank(); ++dim_idx)
292 if (!(dim_idx == concat_axis || in_shape.dim(dim_idx) == first_in_shape.dim(dim_idx)))
293 throw std::runtime_error("All tensor should have same dimension "
294 "except dimension on passed axis");
297 // Calculate output shape
298 ir::Shape out_shape(first_in_shape);
299 out_shape.dim(concat_axis) = 0;
300 for (const auto &in_shape : in_shapes)
301 out_shape.dim(concat_axis) += in_shape.dim(concat_axis);
305 ir::Shape inferConv2DShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
306 const ir::operation::Conv2D::Param ¶m, ir::Layout layout)
308 auto ifm_shape = in_shape.asFeature(layout);
310 // Kernel format is [depth_out, kernel_height, kernel_width, depth_in]
311 auto kf_shape = ker_shape.asFeature(layout);
312 assert(ifm_shape.C == kf_shape.C);
314 const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W,
315 param.padding, param.stride, param.dilation);
317 return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, kf_shape.N};
320 ir::Shape inferDepthwiseConv2DShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
321 const ir::operation::DepthwiseConv2D::Param ¶m,
324 assert(layout == ir::Layout::NHWC);
325 auto ifm_shape = in_shape.asFeature(layout);
327 // Kernel format is [1, kernel_height, kernel_width, depth_out]
328 auto kf_shape = ker_shape.asFeature(layout);
329 assert(kf_shape.C == static_cast<int32_t>(ifm_shape.C * param.multiplier));
330 assert(kf_shape.N == 1);
332 const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W,
333 param.padding, param.stride);
335 return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, kf_shape.C};
338 ir::Shape inferExpandDimsShape(const ir::Shape &in_shape, int32_t axis)
340 ir::Shape out_shape(in_shape.rank() + 1);
342 axis = ((axis >= 0) ? axis : /* when axis < 0 */ (out_shape.rank() + axis));
343 if (!(0 <= axis && axis <= in_shape.rank()))
344 throw std::runtime_error("axis of dim is out of range");
346 for (int x = 0, out_x = 0; out_x < out_shape.rank(); ++out_x)
349 out_shape.dim(out_x) = 1;
351 out_shape.dim(out_x) = in_shape.dim(x++);
357 ir::Shape inferFillShape(const ir::Shape &in_shape, const int32_t *buffer)
359 ir::Shape out_shape(in_shape.dim(0));
361 for (int out_x = 0; out_x < out_shape.rank(); ++out_x)
363 out_shape.dim(out_x) = buffer[out_x];
369 ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape)
371 assert(in_shape.rank() >= 2);
372 assert(ker_shape.rank() == 2);
374 const auto input_size_with_batch = in_shape.num_elements();
375 const auto num_units = ker_shape.dim(0);
376 const auto input_size = ker_shape.dim(1);
377 const auto batch_size = input_size_with_batch / input_size;
378 assert(input_size_with_batch % input_size == 0);
380 return {ir::Shape({static_cast<int32_t>(batch_size), num_units})};
383 ir::Shape inferGatherShape(const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis,
387 const int indices_rank = indices_shape.rank();
388 for (int idx = 0; idx < rank; ++idx)
392 for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
394 out_shape.append(indices_shape.dim(indices_idx));
399 out_shape.append(input_shape.dim(idx));
406 ir::Shape inferOnehotShape(const ir::Shape &input_shape, const int depth, int axis)
409 const auto rank = input_shape.rank() + 1;
410 ir::Shape newShape(rank);
412 axis = (axis == -1) ? (rank - 1) : axis;
414 for (int i = 0; i < rank; ++i)
418 newShape.dim(i) = input_shape.dim(i);
422 newShape.dim(i) = depth;
426 newShape.dim(i) = input_shape.dim(i - 1);
433 ir::Shape inferPackShape(const ir::Shape &input_shape, int axis, int rank, int num)
438 for (int out_idx = 0; out_idx < rank; ++out_idx)
442 out_shape.append(num);
446 out_shape.append(input_shape.dim(in_idx++));
453 ir::Shape inferPadShape(const ir::Shape &in_shape, const int32_t *pad_buf, const size_t num_pads)
455 assert(num_pads % 2 == 0);
456 const int32_t rank = num_pads / 2;
459 for (int32_t i = 0; i < rank; ++i)
461 const auto before_padding = pad_buf[i * 2];
462 const auto after_padding = pad_buf[i * 2 + 1];
464 ret.dim(i) = in_shape.dim(i) + before_padding + after_padding;
470 ir::Shape inferPoolShape(const ir::Shape &in_shape, const ir::operation::Pool2D::Param ¶m,
471 const ir::Layout layout)
473 assert(layout == ir::Layout::NHWC);
474 auto ifm_shape = in_shape.asFeature(layout);
475 const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh, param.kw,
476 param.padding, param.stride);
477 // Pooling don't change number of channels and batch size
478 return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, ifm_shape.C};
481 ir::Shape inferResizeBilinearShape(const ir::Shape &in_shape, const int32_t output_height,
482 const int32_t output_width)
484 assert(in_shape.rank() == 4);
485 ir::Shape ret(in_shape.rank());
487 ret.dim(0) = in_shape.dim(0);
488 ret.dim(1) = output_height;
489 ret.dim(2) = output_width;
490 ret.dim(3) = in_shape.dim(3);
495 template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delta_val)
497 ir::Shape out_shape(static_cast<int>(1));
500 (std::is_integral<T>::value
501 ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
502 : std::ceil(std::abs((start_val - limit_val) / delta_val)));
506 // template instantiation
507 template ir::Shape inferRangeShape(int start_val, int limit_val, int delta_val);
508 template ir::Shape inferRangeShape(float start_val, float limit_val, float delta_val);
510 ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_elements,
511 const size_t total_num_elements)
513 ir::Shape ret(shape_num_elements);
514 int32_t flatten_dim = ir::Shape::UNSPECIFIED_DIM;
515 for (int32_t i = 0; i < shape_num_elements; ++i)
517 if (shape_buf[i] < 0)
519 if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
520 throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice");
526 ret.dim(i) = shape_buf[i];
529 if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
530 ret.dim(flatten_dim) = total_num_elements / ret.num_elements();
533 if (total_num_elements != static_cast<size_t>(ret.num_elements()))
534 throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
539 ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
540 const ir::Shape &input_false_shape)
542 auto haveSameShapes = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
543 const ir::Shape &input_false_shape) {
544 if ((input_cond_shape.rank() != input_true_shape.rank()) ||
545 input_cond_shape.rank() != input_false_shape.rank())
550 int rank = input_cond_shape.rank();
551 for (int i = 0; i < rank; ++i)
553 if (input_cond_shape.dim(i) != input_true_shape.dim(i) ||
554 input_cond_shape.dim(i) != input_false_shape.dim(i))
563 auto calculateShape = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
564 const ir::Shape &input_false_shape, ir::Shape &new_shape) {
565 ir::Shape cond_shape = input_cond_shape;
566 ir::Shape true_shape = input_true_shape;
567 ir::Shape false_shape = input_false_shape;
569 (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
571 : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
573 ir::Shape calculate_shape(most_rank);
575 cond_shape.extendRank(most_rank);
576 true_shape.extendRank(most_rank);
577 false_shape.extendRank(most_rank);
579 for (int i = 0; i < most_rank; ++i)
581 calculate_shape.dim(i) =
582 (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
584 : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
586 if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) ||
587 (true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) ||
588 (false_shape.dim(i) != calculate_shape.dim(i) && false_shape.dim(i) != 1))
594 new_shape = calculate_shape;
599 bool havesame = haveSameShapes(input_cond_shape, input_true_shape, input_false_shape);
602 return input_cond_shape;
606 bool possible = calculateShape(input_cond_shape, input_true_shape, input_false_shape, new_shape);
610 throw std::runtime_error("Broadcasting is not possible.");
616 ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, const int32_t *sizes)
618 const uint32_t rank = input_shape.rank();
619 ir::Shape out_shape(rank);
621 for (uint32_t idx = 0; idx < rank; ++idx)
623 const auto input_dim = input_shape.dim(idx);
625 // begin is zero-based
626 auto begin = begins[idx];
628 throw std::runtime_error("shape inference Slice: Invalid begin.");
631 auto size = sizes[idx];
633 throw std::runtime_error("shape inference Slice: Invalid size.");
637 size = input_dim - begin;
641 if (input_dim < begin + size)
642 throw std::runtime_error("shape inference Slice: Invalid begin and size.");
644 out_shape.dim(idx) = size;
650 ir::Shape inferSpaceToBatchNDShape(const ir::Shape &input_shape, const ir::Shape &block_shape_shape,
651 const ir::Shape &padding_shape, const int32_t *block_shape_data,
652 const int32_t *padding_data)
654 const uint32_t rank = input_shape.rank();
655 ir::Shape out_shape(rank);
657 // Currently, only 4D NHWC input/output op_context are supported.
658 // The 4D array need to have exactly 2 spatial dimensions.
659 // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
660 const int32_t kInputDimensionNum = 4;
661 const int32_t kBlockSizeDimensionNum = 1;
662 const int32_t kSpatialDimensionNum = 2;
664 UNUSED_RELEASE(kInputDimensionNum);
665 UNUSED_RELEASE(kBlockSizeDimensionNum);
666 UNUSED_RELEASE(block_shape_shape);
667 UNUSED_RELEASE(padding_shape);
669 assert(block_shape_shape.rank() == kBlockSizeDimensionNum);
670 assert(block_shape_shape.dim(0) == kSpatialDimensionNum);
671 assert(padding_shape.dim(0) == kSpatialDimensionNum);
672 assert(padding_shape.dim(1) == 2); // fixed, meaning left/right padding for each element
673 assert(padding_shape.rank() == 2); // fixed, meaning dimension(dim 0) and padding length(dim 1)
675 // Ensures the input height and width (with padding) is a multiple of block
676 // shape height and width.
677 for (int dim = 0; dim < kSpatialDimensionNum; ++dim)
680 (input_shape.dim(dim + 1) + padding_data[dim * 2] + padding_data[dim * 2 + 1]);
682 assert(final_dim_size % block_shape_data[dim] == 0);
684 out_shape.dim(dim + 1) = final_dim_size / block_shape_data[dim];
687 const int output_batch_size = input_shape.dim(0) * block_shape_data[0] * block_shape_data[1];
688 const int output_channel_size = input_shape.dim(3);
690 out_shape.dim(0) = output_batch_size;
691 out_shape.dim(3) = output_channel_size;
696 ir::Shape inferSplitShape(const ir::Shape input_shape, int axis_value, int num_splits)
698 ir::Shape newShape(input_shape);
700 assert(axis_value >= 0);
701 assert(axis_value < input_shape.rank());
703 const int input_size = input_shape.dim(axis_value);
704 assert(input_size % num_splits == 0);
705 const int slice_size = input_size / num_splits;
707 newShape.dim(axis_value) = slice_size;
712 ir::Shape inferSqueezeShape(const ir::Shape &in_shape, const ir::operation::Squeeze::Param ¶m)
714 const int ndims = param.ndim;
715 const int *squeeze_dims = param.dims;
716 bool should_squeeze[8] = {false};
717 int num_squeezed_dims = 0;
718 int shape_rank = in_shape.rank();
721 for (int idx = 0; idx < shape_rank; ++idx)
723 if (in_shape.dim(idx) == 1)
725 should_squeeze[idx] = true;
732 for (int idx = 0; idx < ndims; ++idx)
734 int current = squeeze_dims[idx];
737 current += shape_rank;
740 if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1))
742 throw std::runtime_error(
743 "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
746 if (!should_squeeze[current])
750 should_squeeze[current] = true;
755 ir::Shape out_shape(shape_rank - num_squeezed_dims);
756 for (int in_idx = 0, out_idx = 0; in_idx < shape_rank; ++in_idx)
758 if (!should_squeeze[in_idx])
760 out_shape.dim(out_idx++) = in_shape.dim(in_idx);
767 // helper for for StridedSlice
768 template <typename T>
769 StridedSliceParams buildStridedSliceParams(const T *begin, const T *end, const T *strides,
770 const uint32_t begin_mask, const uint32_t end_mask,
771 const uint32_t shrink_axis_mask, const uint8_t rank)
773 StridedSliceParams op_params;
774 op_params.start_indices_count = rank;
775 op_params.stop_indices_count = rank;
776 op_params.strides_count = rank;
778 for (int i = 0; i < op_params.strides_count; ++i)
780 op_params.start_indices[i] = begin[i];
781 op_params.stop_indices[i] = end[i];
782 op_params.strides[i] = strides[i];
784 assert(op_params.strides[i] != 0);
787 op_params.begin_mask = begin_mask;
788 op_params.ellipsis_mask = 0; // NYI
789 op_params.end_mask = end_mask;
790 op_params.new_axis_mask = 0; // NYI
791 op_params.shrink_axis_mask = shrink_axis_mask;
793 assert(sizeof(op_params.begin_mask) * 4 >= rank);
798 // template instantiation
799 template StridedSliceParams
800 buildStridedSliceParams(const uint32_t *begin, const uint32_t *end, const uint32_t *strides,
801 const uint32_t begin_mask, const uint32_t end_mask,
802 const uint32_t shrink_axis_mask, const uint8_t rank);
// Clamp v into the inclusive range [lo, hi].
int Clamp(const int v, const int lo, const int hi)
{
  assert(!(hi < lo));
  if (hi < v)
    return hi;
  if (v < lo)
    return lo;
  return v;
}
814 int StartForAxis(const StridedSliceParams ¶ms, const ir::Shape &input_shape, int axis)
816 const auto begin_mask = params.begin_mask;
817 const auto *start_indices = params.start_indices;
818 const auto *strides = params.strides;
819 // Begin with the specified index.
820 int start = start_indices[axis];
822 // begin_mask override
823 if (begin_mask & 1 << axis)
825 if (strides[axis] > 0)
827 // Forward iteration - use the first element. These values will get
828 // clamped below (Note: We could have set them to 0 and axis_size-1, but
829 // use lowest() and max() to maintain symmetry with StopForAxis())
830 start = std::numeric_limits<int>::lowest();
834 // Backward iteration - use the last element.
835 start = std::numeric_limits<int>::max();
839 // Handle negative indices
840 int axis_size = input_shape.dim(axis);
847 start = Clamp(start, 0, axis_size - 1);
852 // Return the "real" index for the end of iteration along that axis. This is an
853 // "end" in the traditional C sense, in that it points to one past the last
854 // element. ie. So if you were iterating through all elements of a 1D array of
855 // size 4, this function would return 4 as the stop, because it is one past the
856 // "real" indices of 0, 1, 2 & 3.
857 int StopForAxis(const StridedSliceParams ¶ms, const ir::Shape &input_shape, int axis,
860 const auto end_mask = params.end_mask;
861 const auto shrink_axis_mask = params.shrink_axis_mask;
862 const auto *stop_indices = params.stop_indices;
863 const auto *strides = params.strides;
865 // Begin with the specified index
866 const bool shrink_axis = shrink_axis_mask & (1 << axis);
867 int stop = stop_indices[axis];
869 // When shrinking an axis, the end position does not matter (and can be
870 // incorrect when negative indexing is used, see Issue #19260). Always use
871 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
872 // already been adjusted for negative indices.
875 stop = start_for_axis + 1;
879 if (end_mask & (1 << axis))
881 if (strides[axis] > 0)
883 // Forward iteration - use the last element. These values will get
885 stop = std::numeric_limits<int>::max();
889 // Backward iteration - use the first element.
890 stop = std::numeric_limits<int>::lowest();
894 // Handle negative indices
896 const int axis_size = input_shape.dim(axis);
903 // Because the end index points one past the last element, we need slightly
904 // different clamping ranges depending on the direction.
905 if (strides[axis] > 0)
908 stop = Clamp(stop, 0, axis_size);
912 // Backward iteration
913 stop = Clamp(stop, -1, axis_size - 1);
919 ir::Shape inferStridedSliceShape(const ir::Shape &input_shape, const StridedSliceParams &op_params,
924 for (uint32_t idx = 0; idx < rank; ++idx)
926 int32_t stride = op_params.strides[idx];
927 int32_t begin = StartForAxis(op_params, input_shape, idx);
928 int32_t end = StopForAxis(op_params, input_shape, idx, begin);
930 // When shrinking an axis, the end position does not matter (and can be
931 // incorrect when negative indexing is used, see Issue #19260). Always use
932 // begin + 1 to generate a length 1 slice, since begin has
933 // already been adjusted for negative indices by StartForAxis.
934 const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
940 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
941 dim_shape = dim_shape < 0 ? 0 : dim_shape;
944 out_shape.append(dim_shape);
951 ir::Shape inferTileShape(const ir::Shape &in_shape, const int32_t *multiplier)
953 // assert(in_shape.rank() == multiplier.rank());
954 ir::Shape new_Shape(in_shape.rank());
956 for (int i = 0; i < in_shape.rank(); ++i)
958 assert(multiplier[i]); // multiplier[i] shuld not be 0.
959 new_Shape.dim(i) = in_shape.dim(i) * multiplier[i];
964 ir::Shape inferTransposeShape(const ir::Shape &in_shape, const std::vector<int> &perm)
966 if (static_cast<int>(perm.size()) > in_shape.rank())
968 throw std::runtime_error("inferTransposeShape failed, bad rank size: " +
969 std::to_string(static_cast<int>(perm.size())));
971 ir::Shape out_shape(static_cast<int>(perm.size()));
972 for (int idx = 0; idx < static_cast<int>(perm.size()); idx++)
974 if (perm[idx] < 0 || perm[idx] >= static_cast<int>(perm.size()))
976 throw std::runtime_error("inferTransposeShape failed, bad perm value: " +
977 std::to_string(perm[idx]));
979 out_shape.dim(idx) = in_shape.dim(perm[idx]);
984 ir::Shape inferUnpackShape(const ir::Shape &input_shape, int axis, int rank)
988 for (int out_idx = 0; out_idx < rank; out_idx++)
992 out_shape.append(input_shape.dim(out_idx));
999 } // namespace shape_inference
1000 } // namespace onert