From 54ec402b355425557ab47fbfc44404ba002b80ae Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=A1=D0=B5=D1=80=D0=B3=D0=B5=D0=B9=20=D0=91=D0=B0=D1=80?= =?utf8?q?=D0=B0=D0=BD=D0=BD=D0=B8=D0=BA=D0=BE=D0=B2/AI=20Tools=20Lab=20/S?= =?utf8?q?RR/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 30 Jul 2019 22:19:11 +0300 Subject: [PATCH] [mir] Add support for 0-dimensional tensors (#6039) Fixes compilation error of SuperResolution and StyleTransfer networks. Signed-off-by: Sergei Barannikov --- compiler/mir/src/Shape.cpp | 5 ----- compiler/mir/src/TensorVariant.cpp | 8 +++----- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/compiler/mir/src/Shape.cpp b/compiler/mir/src/Shape.cpp index c9a322a..0ef79b5 100644 --- a/compiler/mir/src/Shape.cpp +++ b/compiler/mir/src/Shape.cpp @@ -28,11 +28,6 @@ void Shape::resize(int32_t size) { _dims.resize(size); } int32_t Shape::numElements() const { - if (rank() == 0) - { - return 0; - } - int32_t res = 1; for (int32_t axis = 0; axis < rank(); ++axis) diff --git a/compiler/mir/src/TensorVariant.cpp b/compiler/mir/src/TensorVariant.cpp index 397b279..2380057 100644 --- a/compiler/mir/src/TensorVariant.cpp +++ b/compiler/mir/src/TensorVariant.cpp @@ -71,13 +71,11 @@ TensorVariant::TensorVariant(const TensorVariant &t_old, const Shape &shape) int axis_old = t_old._shape.rank() - 1; for (int d = shape.rank() - 1; d >= 0; d--) { - if (t_old._shape.dim(axis_old) == 1) - _strides[d] = 0; - else - _strides[d] = t_old._strides[axis_old]; - axis_old--; if (axis_old == -1) break; + if (t_old._shape.dim(axis_old) != 1) + _strides[d] = t_old._strides[axis_old]; + axis_old--; } } -- 2.7.4