[mir] Add support for 0-dimensional tensors (#6039)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Tue, 30 Jul 2019 19:19:11 +0000 (22:19 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 30 Jul 2019 19:19:11 +0000 (22:19 +0300)
Fixes compilation error of SuperResolution and StyleTransfer networks.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir/src/Shape.cpp
compiler/mir/src/TensorVariant.cpp

index c9a322a..0ef79b5 100644 (file)
@@ -28,11 +28,6 @@ void Shape::resize(int32_t size) { _dims.resize(size); }
 
 int32_t Shape::numElements() const
 {
-  if (rank() == 0)
-  {
-    return 0;
-  }
-
   int32_t res = 1;
 
   for (int32_t axis = 0; axis < rank(); ++axis)
index 397b279..2380057 100644 (file)
@@ -71,13 +71,11 @@ TensorVariant::TensorVariant(const TensorVariant &t_old, const Shape &shape)
   int axis_old = t_old._shape.rank() - 1;
   for (int d = shape.rank() - 1; d >= 0; d--)
   {
-    if (t_old._shape.dim(axis_old) == 1)
-      _strides[d] = 0;
-    else
-      _strides[d] = t_old._strides[axis_old];
-    axis_old--;
     if (axis_old == -1)
       break;
+    if (t_old._shape.dim(axis_old) != 1)
+      _strides[d] = t_old._strides[axis_old];
+    axis_old--;
   }
 }