From c3501839cdce36c0555c888b284e4fd89e64e6f8 Mon Sep 17 00:00:00 2001
From: =?utf8?q?=D0=94=D0=B8=D0=BB=D1=88=D0=BE=D0=B4=D0=B6=D0=BE=D0=BD=20?=
 =?utf8?q?=D0=A3=D0=BC=D1=80=D0=BE=D0=BD=D1=85=D0=BE=D0=BD=D0=BE=D0=B2?=
 =?utf8?q?=D0=B8=D1=87=20=D0=9F=D0=BE=D1=88=D1=88=D0=BE=D0=B5=D0=B2/AI=20T?=
 =?utf8?q?ools=20Lab=20/SRR/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Mon, 14 Jan 2019 13:09:41 +0300
Subject: [PATCH] Refactoring of Shape::broadcast (#4132)

Remove 4 loops and the creation of two temp vectors; handle all checks in
one loop

Signed-off-by: Poshshoev Dilshodzhon
---
 libs/tflite/src/TensorShapeUtils.cpp | 27 ++++-----------------------
 1 file changed, 4 insertions(+), 23 deletions(-)

diff --git a/libs/tflite/src/TensorShapeUtils.cpp b/libs/tflite/src/TensorShapeUtils.cpp
index b5d9067..29628cd 100644
--- a/libs/tflite/src/TensorShapeUtils.cpp
+++ b/libs/tflite/src/TensorShapeUtils.cpp
@@ -11,34 +11,15 @@ nnfw::misc::tensor::Shape broadcast(const nnfw::misc::tensor::Shape &lhs_shape,
   const uint32_t lhs_rank = lhs_shape.rank();
   const uint32_t rhs_rank = rhs_shape.rank();
   const uint32_t out_rank = std::max(lhs_rank, rhs_rank);
-
-  // TODO Simplify implementation
-  std::vector<int32_t> lhs_normalized_dims;
-  std::vector<int32_t> rhs_normalized_dims;
-
-  for (uint32_t n = 0; n < out_rank - lhs_rank; ++n)
-  {
-    lhs_normalized_dims.emplace_back(1);
-  }
-  for (uint32_t axis = 0; axis < lhs_rank; ++axis)
-  {
-    lhs_normalized_dims.emplace_back(lhs_shape.dim(axis));
-  }
-
-  for (uint32_t n = 0; n < out_rank - rhs_rank; ++n)
-  {
-    rhs_normalized_dims.emplace_back(1);
-  }
-  for (uint32_t axis = 0; axis < rhs_rank; ++axis)
-  {
-    rhs_normalized_dims.emplace_back(rhs_shape.dim(axis));
-  }
+  const uint32_t lhs_rank_diff = out_rank - lhs_rank;
+  const uint32_t rhs_rank_diff = out_rank - rhs_rank;
 
   nnfw::misc::tensor::Shape out_shape(out_rank);
 
   for (uint32_t axis = 0; axis < out_rank; ++axis)
   {
-    out_shape.dim(axis) = std::max(lhs_normalized_dims.at(axis), rhs_normalized_dims.at(axis));
+    out_shape.dim(axis) = std::max(axis < lhs_rank_diff ? 1 : lhs_shape.dim(axis - lhs_rank_diff),
+                                   axis < rhs_rank_diff ? 1 : rhs_shape.dim(axis - rhs_rank_diff));
   }
 
   return out_shape;
-- 
2.7.4