From 505a0952edbbcd0d9bcb156ef9abbe70c6f91db4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Vladimir=20Plazun/AI=20Tools=20Lab=20/SRR/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 5 Dec 2018 18:00:20 +0300 Subject: [PATCH] [nnc] Replace zero-dim tensors with tensor of shape {1} (#2513) Fixes multiple issues involving zero-dimension tensors. Includes Bias, Scale, Reduce and Elementwise operations Signed-off-by: Vladimir Plazun --- contrib/nnc/include/passes/common_frontend/shape_helper.h | 5 +++++ contrib/nnc/passes/tflite_frontend/tflite_importer.cpp | 2 ++ contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp | 12 +++++------- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/contrib/nnc/include/passes/common_frontend/shape_helper.h b/contrib/nnc/include/passes/common_frontend/shape_helper.h index 186d2fc..a0b9842 100644 --- a/contrib/nnc/include/passes/common_frontend/shape_helper.h +++ b/contrib/nnc/include/passes/common_frontend/shape_helper.h @@ -32,6 +32,11 @@ public: template mir::Shape ShapeHelper::createShape(const Iterable &iter, std::size_t size) { + //Zero-dim tensor is just a tensor with 1 element + if (size == 0) { + return mir::Shape{1}; + } + mir::Shape sh; sh.resize(static_cast(size)); diff --git a/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp b/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp index b647401..84de4e6 100644 --- a/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp +++ b/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp @@ -330,6 +330,8 @@ std::shared_ptr TfliteImporter::createTensor(const Tensor* t, const Bu Shape tensor_shape = ShapeHelper::createShape(*t->shape(), t->shape()->size()); + assert(tensor_shape.numElements() * elementSize == b->data()->size()); + return std::make_shared(tensor_shape, tensor_buffer_copy, type, elementSize); } diff --git a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp index b8304a6..1828fda 100644 --- a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp +++ b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp @@ -208,14 +208,12 @@ std::vector TFLiteOpCreator::convertReducer(InputOps inputs, In auto tensor = mir::Tensor(*params.at(0)); std::vector axes; - if (params.at(0)->getShape().rank() == 0) { - // TODO: Dangerous black magic (Default construced Index is 0 dim, as is 0 dim Tensor) - axes.push_back(tensor.at(Index())); - } else { - for (const auto& i: mir::ShapeRange(tensor.getShape())) { - axes.emplace_back(tensor.at(i)); - } + for (const auto& i: mir::ShapeRange(tensor.getShape())) { + axes.emplace_back(tensor.at(i)); } + + std::sort(axes.begin(), axes.end()); + return createOp( ActivationFunctionType_NONE, inputs[0]->getOutput(0), axes, opts->keep_dims(), ft); -- 2.7.4