From 6b311b09debb7865ce97e5c2f74714c20851a6ac Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 1 Jun 2018 11:35:13 +0900 Subject: [PATCH] [support.tflite] Generic TensorView (#1494) This commit revises TensorView class as a generic class. Signed-off-by: Jonghyun Park --- include/support/tflite/TensorView.h | 42 +++++++++++++---- libs/support/tflite/src/TensorView.cpp | 70 ----------------------------- libs/support/tflite/src/TensorView.test.cpp | 17 +++++++ 3 files changed, 51 insertions(+), 78 deletions(-) delete mode 100644 libs/support/tflite/src/TensorView.cpp diff --git a/include/support/tflite/TensorView.h b/include/support/tflite/TensorView.h index 35c90a3..0475a4b 100644 --- a/include/support/tflite/TensorView.h +++ b/include/support/tflite/TensorView.h @@ -31,30 +31,56 @@ namespace support namespace tflite { -template class TensorView; - -template<> class TensorView final : public nnfw::util::tensor::Reader +template class TensorView final : public nnfw::util::tensor::Reader { public: - TensorView(const nnfw::util::tensor::Shape &shape, float *base); + TensorView(const nnfw::util::tensor::Shape &shape, T *base) + : _shape{shape}, _base{base} + { + // Set 'stride' + _stride.init(_shape); + } public: const nnfw::util::tensor::Shape &shape(void) const { return _shape; } public: - float at(const nnfw::util::tensor::Index &index) const override; - float &at(const nnfw::util::tensor::Index &index); + T at(const nnfw::util::tensor::Index &index) const override + { + const auto offset = _stride.offset(index); + return *(_base + offset); + } + +public: + T &at(const nnfw::util::tensor::Index &index) + { + const auto offset = _stride.offset(index); + return *(_base + offset); + } private: nnfw::util::tensor::Shape _shape; public: - float *_base; + T *_base; nnfw::util::tensor::NonIncreasingStride _stride; public: // TODO Introduce Operand ID class - static TensorView make(::tflite::Interpreter &interp, int operand_id); + static TensorView make(::tflite::Interpreter &interp, int tensor_index) + { + auto tensor_ptr = interp.tensor(tensor_index); + + // Set 'shape' + nnfw::util::tensor::Shape shape(tensor_ptr->dims->size); + + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + shape.dim(axis) = tensor_ptr->dims->data[axis]; + } + + return TensorView(shape, interp.typed_tensor(tensor_index)); + } }; } // namespace tflite diff --git a/libs/support/tflite/src/TensorView.cpp b/libs/support/tflite/src/TensorView.cpp deleted file mode 100644 index 2618d1c..0000000 --- a/libs/support/tflite/src/TensorView.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "support/tflite/TensorView.h" - -#include - -namespace nnfw -{ -namespace support -{ -namespace tflite -{ - -TensorView::TensorView(const nnfw::util::tensor::Shape &shape, float *base) - : _shape{shape}, _base{base} -{ - // Set 'stride' - _stride.init(_shape); -} - -float TensorView::at(const nnfw::util::tensor::Index &index) const -{ - const auto offset = _stride.offset(index); - - return *(_base + offset); -} - -float &TensorView::at(const nnfw::util::tensor::Index &index) -{ - const auto offset = _stride.offset(index); - - return *(_base + offset); -} - -TensorView TensorView::make(::tflite::Interpreter &interp, int tensor_index) -{ - auto tensor_ptr = interp.tensor(tensor_index); - - // TODO Enable the following assets - // assert(isFloatTensor(tensor_ptr)); - // assert(isFeatureTensor(tensor_ptr)); - - // Set 'shape' - nnfw::util::tensor::Shape shape(tensor_ptr->dims->size); - - for (uint32_t axis = 0; axis < shape.rank(); ++axis) - { - shape.dim(axis) = tensor_ptr->dims->data[axis]; - } - - return TensorView(shape, interp.typed_tensor(tensor_index)); -} - -} // namespace tflite -} // namespace support -} // namespace nnfw diff --git a/libs/support/tflite/src/TensorView.test.cpp b/libs/support/tflite/src/TensorView.test.cpp index 61a2723..1d3a705 100644 --- a/libs/support/tflite/src/TensorView.test.cpp +++ b/libs/support/tflite/src/TensorView.test.cpp @@ -18,6 +18,21 @@ #include +void int_test(void) +{ + int value[6] = {1, 2, 3, 4, 5, 6}; + + const nnfw::util::tensor::Shape shape{2, 3}; + const nnfw::support::tflite::TensorView view{shape, value}; + + assert(view.at(nnfw::util::tensor::Index{0, 0}) == 1); + assert(view.at(nnfw::util::tensor::Index{0, 1}) == 2); + assert(view.at(nnfw::util::tensor::Index{0, 2}) == 3); + assert(view.at(nnfw::util::tensor::Index{1, 0}) == 4); + assert(view.at(nnfw::util::tensor::Index{1, 1}) == 5); + assert(view.at(nnfw::util::tensor::Index{1, 2}) == 6); +} + int main(int argc, char **argv) { float value[6] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -32,5 +47,7 @@ int main(int argc, char **argv) assert(view.at(nnfw::util::tensor::Index{1, 1}) == 5.0f); assert(view.at(nnfw::util::tensor::Index{1, 2}) == 6.0f); + int_test(); + return 0; } -- 2.7.4