/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "TFLShapeInferenceRule.h"

#include "Dialect/IR/TFLNodes.h"
#include "Dialect/IR/TFLDialect.h"
#include "Dialect/IR/TFLNodeVisitor.h"

#include "Check.h"

#include <oops/InternalExn.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

namespace
{
// Call this for TFLAveragePool2D and TFLMaxPool2D only
template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
{
  EXO_ASSERT(loco::shape_known(node->value()), "Shape must be known");

  auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
  assert(ifm_shape.rank() == 4);

  uint32_t input_height = ifm_shape.dim(1).value();
  uint32_t input_width = ifm_shape.dim(2).value();
  uint32_t stride_height = node->stride()->h();
  uint32_t stride_width = node->stride()->w();
  uint32_t window_height = node->filter()->h();
  uint32_t window_width = node->filter()->w();
  uint32_t dilation_height = 1; // dilation for TFLAveragePool2D and TFLMaxPool2D is always 1
  uint32_t dilation_width = 1;
  uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
  uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;

  uint32_t output_height = 0;
  uint32_t output_width = 0;

  if (node->padding() == locoex::Padding::VALID)
  {
    output_height = (input_height + stride_height - effective_window_height) / stride_height;
    output_width = (input_width + stride_width - effective_window_width) / stride_width;
  }
  else if (node->padding() == locoex::Padding::SAME)
  {
    output_height = (input_height + stride_height - 1) / stride_height;
    output_width = (input_width + stride_width - 1) / stride_width;
  }
  else
    EXO_ASSERT(false, "Wrong padding type");

  loco::TensorShape ofm_shape;
  ofm_shape.rank(4);
  ofm_shape.dim(0) = ifm_shape.dim(0);
  ofm_shape.dim(1) = output_height;
  ofm_shape.dim(2) = output_width;
  ofm_shape.dim(3) = ifm_shape.dim(3);

  return loco::NodeShape{ofm_shape};
}
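// Illustrative arithmetic (not from the original source):
//   VALID: output = (input + stride - effective_window) / stride
//          e.g. input 5, window 2, stride 2 -> (5 + 2 - 2) / 2 = 2
//   SAME : output = (input + stride - 1) / stride, i.e. ceil(input / stride)
//          e.g. input 5, stride 2 -> (5 + 2 - 1) / 2 = 3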
/**
 * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
 *
 * Usage:
 *
 *   auto expanded_tensor_shape = TensorShapeExpander(tensor_shape).to(N);
 */
class TensorShapeExpander
{
public:
  TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
  {
    // DO NOTHING
  }

public:
  loco::TensorShape to(uint32_t output_rank)
  {
    auto const &input_shape = _shape;
    uint32_t const input_rank = input_shape.rank();

    assert(input_rank <= output_rank && "Cannot shrink rank");
    uint32_t const axis_shift = output_rank - input_rank;

    loco::TensorShape output_shape;

    output_shape.rank(output_rank);
    for (uint32_t axis = 0; axis < output_rank; ++axis)
    {
      output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
    }

    return output_shape;
  }

private:
  const loco::TensorShape _shape;
};
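// Illustrative usage (hypothetical values, not from the original source):
//   loco::TensorShape s;                     // s = [3, 4]
//   s.rank(2); s.dim(0) = 3; s.dim(1) = 4;
//   auto t = TensorShapeExpander(s).to(4);   // t = [1, 1, 3, 4]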
/**
 * @brief Expand shapes x and y to the same rank by aligning right and filling with 1
 */
void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
{
  auto x_rank = x.rank();
  auto y_rank = y.rank();

  if (x_rank == y_rank)
    return;

  TensorShapeExpander x_exp(x);
  TensorShapeExpander y_exp(y);

  auto xy_rank = std::max(x_rank, y_rank);

  x = x_rank > y_rank ? x : x_exp.to(xy_rank);
  y = y_rank > x_rank ? y : y_exp.to(xy_rank);
}
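// Illustrative example (not from the original source):
//   x = [3, 4], y = [2, 3, 4]  ->  x becomes [1, 3, 4]; y is unchanged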
/**
 * @brief Returns the broadcast shape of inputs x and y, which must have the same rank
 */
loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
{
  assert(x.rank() == y.rank());

  auto rank = x.rank();

  loco::TensorShape output_shape;

  output_shape.rank(rank);
  for (uint32_t axis = 0; axis < rank; ++axis)
  {
    assert(x.dim(axis).known() && y.dim(axis).known());

    auto x_dim = x.dim(axis).value();
    auto y_dim = y.dim(axis).value();

    // Each dimension of x and y must match, or one of them must be 1
    if (!(x_dim == y_dim || x_dim == 1 || y_dim == 1))
      INTERNAL_EXN("Cannot produce expand_dimension of two shapes");

    output_shape.dim(axis) = std::max(x_dim, y_dim);
  }

  return output_shape;
}
loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
{
  auto x_match = x;
  auto y_match = y;

  expand_rank(x_match, y_match);

  auto output_shape = expand_dimension(x_match, y_match);

  return output_shape;
}
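// Illustrative example (not from the original source):
//   broadcast_shape([2, 1, 5], [3, 1])  ->  [2, 3, 5]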
/**
 * @brief Class to infer the shape of TFLNode
 *
 * @note All TFLNode's inputs and outputs are always loco::Domain::Tensor
 */
class ShapeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::NodeShape>
{
public:
  loco::NodeShape visit(const locoex::TFLAdd *node) final
  {
    auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
    auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();

    auto output_shape = broadcast_shape(x_shape, y_shape);

    return loco::NodeShape{output_shape};
  }
  loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final
  {
    return infer_pool_2d_shape(node);
  }
  loco::NodeShape visit(const locoex::TFLConcatenation *node) final
  {
    // TODO Support when TFLConcatenation has 0 input
    assert(node->numValues() > 0);

    auto axis = node->axis();
    auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();

    loco::TensorShape output_shape;

    output_shape.rank(first_shape.rank());
    for (uint32_t i = 0; i < output_shape.rank(); ++i)
      output_shape.dim(i) = first_shape.dim(i);

    for (uint32_t i = 1; i < node->numValues(); ++i)
    {
      auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();

      for (uint32_t j = 0; j < output_shape.rank(); ++j)
      {
        if (j == static_cast<uint32_t>(axis))
          output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
        else
          assert(output_shape.dim(j) == input_shape.dim(j));
      }
    }

    return loco::NodeShape{output_shape};
  }
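  // Illustrative example (not from the original source):
  //   concatenating [2, 3] and [2, 5] along axis 1 -> [2, 8]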
  loco::NodeShape visit(const locoex::TFLConst *node) final
  {
    loco::TensorShape shape;

    shape.rank(node->rank());
    for (uint32_t axis = 0; axis < node->rank(); axis++)
      shape.dim(axis) = node->dim(axis);

    return loco::NodeShape{shape};
  }
  loco::NodeShape visit(const locoex::TFLConv2D *node) final
  {
    auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
    auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI

    assert(ifm_shape.rank() == 4);
    assert(ker_shape.rank() == 4);
    assert(ifm_shape.dim(3) == ker_shape.dim(3));

    uint32_t input_height = ifm_shape.dim(1).value();
    uint32_t input_width = ifm_shape.dim(2).value();
    uint32_t stride_height = node->stride()->h();
    uint32_t stride_width = node->stride()->w();
    uint32_t ker_height = ker_shape.dim(1).value();
    uint32_t ker_width = ker_shape.dim(2).value();
    uint32_t dilation_height = 1;
    uint32_t dilation_width = 1;
    uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
    uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;

    uint32_t output_height = 0;
    uint32_t output_width = 0;

    if (node->padding() == locoex::Padding::VALID)
    {
      output_height = (input_height + stride_height - effective_ker_height) / stride_height;
      output_width = (input_width + stride_width - effective_ker_width) / stride_width;
    }
    else if (node->padding() == locoex::Padding::SAME)
    {
      output_height = (input_height + stride_height - 1) / stride_height;
      output_width = (input_width + stride_width - 1) / stride_width;
    }
    else
      EXO_ASSERT(false, "Wrong padding type");

    loco::TensorShape ofm_shape;
    ofm_shape.rank(4);
    ofm_shape.dim(0) = ifm_shape.dim(0);
    ofm_shape.dim(1) = output_height;
    ofm_shape.dim(2) = output_width;
    ofm_shape.dim(3) = ker_shape.dim(0);

    return loco::NodeShape{ofm_shape};
  }
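  // Illustrative example (not from the original source):
  //   ifm [1, 32, 32, 3] (NHWC), ker [16, 3, 3, 3] (OHWI), stride 1, SAME padding
  //   -> ofm [1, 32, 32, 16]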
  loco::NodeShape visit(const locoex::TFLDepthwiseConv2D *node) final
  {
    auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
    auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1HW(C*M)

    assert(ifm_shape.rank() == 4);
    assert(ker_shape.rank() == 4);
    assert(ker_shape.dim(0).value() == 1);

    uint32_t input_height = ifm_shape.dim(1).value();
    uint32_t input_width = ifm_shape.dim(2).value();
    uint32_t stride_height = node->stride()->h();
    uint32_t stride_width = node->stride()->w();
    uint32_t ker_height = ker_shape.dim(1).value();
    uint32_t ker_width = ker_shape.dim(2).value();
    uint32_t dilation_height = 1;
    uint32_t dilation_width = 1;
    uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
    uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;

    uint32_t output_height = 0;
    uint32_t output_width = 0;

    if (node->padding() == locoex::Padding::VALID)
    {
      output_height = (input_height + stride_height - effective_ker_height) / stride_height;
      output_width = (input_width + stride_width - effective_ker_width) / stride_width;
    }
    else if (node->padding() == locoex::Padding::SAME)
    {
      output_height = (input_height + stride_height - 1) / stride_height;
      output_width = (input_width + stride_width - 1) / stride_width;
    }
    else
      EXO_ASSERT(false, "Wrong padding type");

    loco::TensorShape ofm_shape;
    ofm_shape.rank(4);
    ofm_shape.dim(0) = ifm_shape.dim(0);
    ofm_shape.dim(1) = output_height;
    ofm_shape.dim(2) = output_width;
    ofm_shape.dim(3) = ker_shape.dim(3);

    return loco::NodeShape{ofm_shape};
  }
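  // Illustrative example (not from the original source):
  //   ifm [1, 14, 14, 8] (NHWC), ker [1, 3, 3, 16] (1HW(C*M), depth multiplier 2),
  //   stride 1, SAME padding -> ofm [1, 14, 14, 16]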
  loco::NodeShape visit(const locoex::TFLDiv *node) final
  {
    auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
    auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();

    auto output_shape = broadcast_shape(x_shape, y_shape);

    return loco::NodeShape{output_shape};
  }
  loco::NodeShape visit(const locoex::TFLFullyConnected *node) final
  {
    auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
    auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>();

    // Checking shape capability for multiplication
    EXO_ASSERT(input_shape.rank() == 2, "NYI for input shape rank > 2");
    EXO_ASSERT(weights_shape.rank() == 2, "Incompatible weights rank for fully connected");
    EXO_ASSERT(input_shape.dim(1) == weights_shape.dim(1),
               "Incompatible shapes for fully connected");

    loco::TensorShape out_shape;

    out_shape.rank(2);
    out_shape.dim(0) = input_shape.dim(0);
    out_shape.dim(1) = weights_shape.dim(0);

    return loco::NodeShape{out_shape};
  }
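  // Illustrative example (not from the original source):
  //   input [B, I] x weights [O, I] -> output [B, O],
  //   e.g. [4, 128] x [10, 128] -> [4, 10]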
  loco::NodeShape visit(const locoex::TFLMaximum *node) final
  {
    auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
    auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();

    auto output_shape = broadcast_shape(x_shape, y_shape);

    return loco::NodeShape{output_shape};
  }
  loco::NodeShape visit(const locoex::TFLMaxPool2D *node) final
  {
    return infer_pool_2d_shape(node);
  }
  loco::NodeShape visit(const locoex::TFLMean *node) final
  {
    const loco::DataType S32 = loco::DataType::S32;

    auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
    auto reduction_indices = dynamic_cast<locoex::TFLConst *>(node->reduction_indices());

    { // Exceptions
      // TODO support non-const case
      EXO_ASSERT(reduction_indices, "Only support constant reduction_indices");
      // TODO support other data type
      EXO_ASSERT(reduction_indices->dtype() == S32, "Only support int 32");
    }

    std::vector<int32_t> reduction_values;

    for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i)
    {
      int32_t axis = reduction_indices->at<S32>(i);
      if (axis < 0)
        axis += input_shape.rank();
      if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank())))
        INTERNAL_EXN_V("Invalid reduction axis for MEAN", oops::to_uint32(axis));
      reduction_values.push_back(axis);
    }

    loco::TensorShape output_shape;

    if (node->keep_dims())
    {
      output_shape.rank(input_shape.rank());
      for (uint32_t i = 0; i < input_shape.rank(); ++i)
        output_shape.dim(i) = input_shape.dim(i);
      for (uint32_t i = 0; i < reduction_values.size(); ++i)
        output_shape.dim(reduction_values.at(i)) = 1;
    }
    else
    {
      std::vector<bool> check_reduce(input_shape.rank(), false);
      for (uint32_t i = 0; i < reduction_values.size(); ++i)
        check_reduce.at(reduction_values.at(i)) = true;

      uint32_t reduce_cnt = 0;
      for (uint32_t i = 0; i < check_reduce.size(); ++i)
        if (check_reduce.at(i))
          ++reduce_cnt;

      output_shape.rank(input_shape.rank() - reduce_cnt);
      for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
        if (check_reduce.at(i) == false)
          output_shape.dim(j++) = input_shape.dim(i);
    }

    return loco::NodeShape{output_shape};
  }
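  // Illustrative example (not from the original source): input [2, 3, 4], axes {1}
  //   keep_dims == true  -> [2, 1, 4]
  //   keep_dims == false -> [2, 4]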
  loco::NodeShape visit(const locoex::TFLMul *node) final
  {
    auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
    auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();

    auto output_shape = broadcast_shape(x_shape, y_shape);

    return loco::NodeShape{output_shape};
  }
  loco::NodeShape visit(const locoex::TFLRelu *node) final
  {
    auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();

    return loco::NodeShape{input_shape};
  }
  loco::NodeShape visit(const locoex::TFLRelu6 *node) final
  {
    auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();

    return loco::NodeShape{input_shape};
  }
  /**
   * @note TFLReshape has new shape info in two places: 2nd input and attribute.
   *       This shape inference forces both to exist and to match each other.
   *       When this condition is satisfied, it returns the inferred shape.
   *
   * TODO Change this policy when not appropriate
   */
  loco::NodeShape visit(const locoex::TFLReshape *node) final
  {
    const loco::DataType S32 = loco::DataType::S32;

    loco::TensorShape shape_by_input;
    {
      EXO_ASSERT(node->shape(), "2nd input shape() should not be nullptr");

      // Only support node's shape() is TFLConst with S32
      // TODO support other node with other types
      auto const_shape_node = dynamic_cast<locoex::TFLConst *>(node->shape());
      EXO_ASSERT(const_shape_node, "Only support TFLConst for shape of TFLReshape");
      EXO_ASSERT(const_shape_node->dtype() == S32, "Only support int32 TFLConst");

      if (const_shape_node->rank() != 1)
        INTERNAL_EXN_V("Only support rank 1 TFLConst", oops::to_uint32(const_shape_node->rank()));

      shape_by_input.rank(const_shape_node->dim(0).value());

      for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
      {
        EXO_ASSERT(const_shape_node->at<S32>(axis) > 0, "Dimension should be > 0")
        shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
      }
    }

    loco::TensorShape shape_by_attr;
    {
      shape_by_attr.rank(node->newShape()->rank());

      for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
      {
        EXO_ASSERT(node->newShape()->dim(axis) > 0, "Dimension should be > 0")
        shape_by_attr.dim(axis) = node->newShape()->dim(axis);
      }
    }

    EXO_ASSERT(shape_by_input == shape_by_attr,
               "Warning: Two new shape information mismatched for TFLReshape");

    return loco::NodeShape{shape_by_input};
  }
  loco::NodeShape visit(const locoex::TFLRsqrt *node) final
  {
    auto input_shape = loco::shape_get(node->x()).as<loco::TensorShape>();

    return loco::NodeShape{input_shape};
  }
  loco::NodeShape visit(const locoex::TFLSqrt *node) final
  {
    auto input_shape = loco::shape_get(node->x()).as<loco::TensorShape>();

    return loco::NodeShape{input_shape};
  }
  loco::NodeShape visit(const locoex::TFLSquaredDifference *node) final
  {
    auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
    auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();

    auto output_shape = broadcast_shape(x_shape, y_shape);

    return loco::NodeShape{output_shape};
  }
  loco::NodeShape visit(const locoex::TFLSub *node) final
  {
    auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
    auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();

    auto output_shape = broadcast_shape(x_shape, y_shape);

    return loco::NodeShape{output_shape};
  }
  /// @brief Returns output shape of transpose. Use loco::ConstGen and locoex::TFLConst for ConstT.
  template <class ConstT>
  loco::TensorShape output_shape_of_transpose(loco::TensorShape input_shape,
                                              const ConstT *perm_node)
  {
    loco::TensorShape output_shape;
    output_shape.rank(input_shape.rank());

    assert(perm_node->dtype() == loco::DataType::S32);
    assert(input_shape.rank() == perm_node->template size<loco::DataType::S32>());

    for (uint32_t out_axis = 0; out_axis < output_shape.rank(); out_axis++)
    {
      // Per TF transpose semantics: output.dim(i) = input.dim(perm[i])
      auto in_axis = perm_node->template at<loco::DataType::S32>(out_axis);
      output_shape.dim(out_axis) = input_shape.dim(in_axis);
    }

    return output_shape;
  }
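  // Illustrative example (not from the original source), following tf.transpose:
  //   input [1, 4, 9], perm [2, 0, 1] -> output [9, 1, 4]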
  loco::NodeShape visit(const locoex::TFLTranspose *node) final
  {
    auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>();

    auto canon_perm = dynamic_cast<loco::ConstGen *>(node->perm());
    auto tfl_perm = dynamic_cast<locoex::TFLConst *>(node->perm());

    if (canon_perm)
    {
      return loco::NodeShape{output_shape_of_transpose(input_shape, canon_perm)};
    }
    else if (tfl_perm)
    {
      return loco::NodeShape{output_shape_of_transpose(input_shape, tfl_perm)};
    }
    else
      INTERNAL_EXN("perm of TFLTranspose should be either ConstGen or TFLConst");
  }
  loco::NodeShape visit(const locoex::TFLTransposeConv *node) final
  {
    // TransposeConv's output shape is written in its 'inputSizes' argument
    auto input_sizes_const = dynamic_cast<locoex::TFLConst *>(node->inputSizes());
    EXO_ASSERT(input_sizes_const, "Only support when TFLTransposeConv's inputSizes is TFLConst")
    EXO_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
    EXO_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
               "Only support rank 1 with 4 entries")

    loco::TensorShape shape;

    shape.rank(4);
    for (uint32_t axis = 0; axis < 4; ++axis)
      shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis);

    return loco::NodeShape{shape};
  }
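  // Illustrative example (not from the original source): if inputSizes holds
  // {1, 64, 64, 32}, the inferred output shape is [1, 64, 64, 32] (NHWC).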
};

} // namespace

namespace locoex
{

bool TFLShapeInferenceRule::recognize(const loco::Dialect *d) const
{
  return TFLDialect::get() == d;
}

bool TFLShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
{
  assert(node->dialect() == TFLDialect::get());
  assert(dynamic_cast<const TFLNode *>(node) != nullptr);

  ShapeInferenceAlgorithm alg;
  shape = dynamic_cast<const TFLNode *>(node)->accept(&alg);

  return true;
}

} // namespace locoex