[nnc] Replace zero-dim tensors with tensor of shape {1} (#2513)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Wed, 5 Dec 2018 15:00:20 +0000 (18:00 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 5 Dec 2018 15:00:20 +0000 (18:00 +0300)
Fixes multiple issues involving zero-dimension tensors. Includes Bias, Scale, Reduce and Elementwise operations

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/include/passes/common_frontend/shape_helper.h
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp

index 186d2fc..a0b9842 100644 (file)
@@ -32,6 +32,11 @@ public:
 template<typename Iterable>
 mir::Shape ShapeHelper::createShape(const Iterable &iter, std::size_t size)
 {
+  //Zero-dim tensor is just a tensor with 1 element
+  if (size == 0) {
+    return mir::Shape{1};
+  }
+
   mir::Shape sh;
   sh.resize(static_cast<int32_t>(size));
 
index b647401..84de4e6 100644 (file)
@@ -330,6 +330,8 @@ std::shared_ptr<IrTensor> TfliteImporter::createTensor(const Tensor* t, const Bu
 
   Shape tensor_shape = ShapeHelper::createShape(*t->shape(), t->shape()->size());
 
+  assert(tensor_shape.numElements() * elementSize == b->data()->size());
+
   return std::make_shared<IrTensor>(tensor_shape, tensor_buffer_copy, type, elementSize);
 }
 
index b8304a6..1828fda 100644 (file)
@@ -208,14 +208,12 @@ std::vector<mir::Operation*> TFLiteOpCreator::convertReducer(InputOps inputs, In
   auto tensor = mir::Tensor<int>(*params.at(0));
   std::vector<int32_t> axes;
 
-  if (params.at(0)->getShape().rank() == 0) {
-    // TODO: Dangerous black magic (Default construced Index is 0 dim, as is 0 dim Tensor)
-    axes.push_back(tensor.at(Index()));
-  } else {
-    for (const auto& i: mir::ShapeRange(tensor.getShape())) {
-      axes.emplace_back(tensor.at(i));
-    }
+  for (const auto& i: mir::ShapeRange(tensor.getShape())) {
+    axes.emplace_back(tensor.at(i));
   }
+
+  std::sort(axes.begin(), axes.end());
+
   return createOp<ops::ReduceFOp>(
     ActivationFunctionType_NONE, inputs[0]->getOutput(0),
     axes, opts->keep_dims(), ft);