From 8b44b0e1d526acc2d85f07adeb332bf6dba516dc Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=90=D0=BD=D0=B4=D1=80=D0=B5=D0=B9=20=D0=A8=D0=B5=D0=B4?= =?utf8?q?=D1=8C=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 19 Dec 2018 19:05:46 +0300 Subject: [PATCH] [nnc] Added Resize implementation in SB (#2398) SoftBackend implementation of Resize Nearest Neighbor Lifted from TFLite Reference ops. Fixed Interpreter implementation. Signed-off-by: Andrei Shedko --- .../nnc/include/core/modelIR/operations/ResizeOp.h | 4 ++ contrib/nnc/passes/interpreter/Interpreter.cpp | 12 ++--- contrib/nnc/passes/soft_backend/CPPGenerator.cpp | 2 + contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp | 9 +++- contrib/nnc/passes/soft_backend/SBSerializer.cpp | 56 ++++++++------------ .../soft_backend/code_snippets/cpp_operations.def | 23 ++++++++ .../soft_backend/code_snippets/cpp_resize.def | 61 ++++++++++++++++++++++ .../nnc/unittests/soft_backend/CPPOperations.cpp | 52 +++++++++++++++++- 8 files changed, 178 insertions(+), 41 deletions(-) create mode 100644 contrib/nnc/passes/soft_backend/code_snippets/cpp_resize.def diff --git a/contrib/nnc/include/core/modelIR/operations/ResizeOp.h b/contrib/nnc/include/core/modelIR/operations/ResizeOp.h index 639a2f4..3d04a86 100644 --- a/contrib/nnc/include/core/modelIR/operations/ResizeOp.h +++ b/contrib/nnc/include/core/modelIR/operations/ResizeOp.h @@ -26,6 +26,10 @@ namespace nnc { namespace mir { namespace ops { +/**@brief Resize operation + * scales are such that output = input * scale for each dimension + * and the number of dimensions matches + */ class ResizeOp : public Operation { public: diff --git a/contrib/nnc/passes/interpreter/Interpreter.cpp b/contrib/nnc/passes/interpreter/Interpreter.cpp index e001f32..a6544e9 100644 --- a/contrib/nnc/passes/interpreter/Interpreter.cpp +++ b/contrib/nnc/passes/interpreter/Interpreter.cpp @@ -351,15 +351,15 @@ void 
NNInterpreter::visit(ops::ResizeOp& op) { mapByName(&op); auto operand = op.getPrevNodes()[0]; Tensor input(var(operand.op->getId())[operand.index]); - assert(input.getShape().rank() == 4 && "Must be rank 4 (for now)"); switch (op.getMode()) { case ops::ResizeOp::ResizeMethod::nearestNeighbor: { auto scales = op.getScales(); - var(op.getId()) = Fill(op.getOutputShape(0), [&scales, &input, &op](const Index& id) { - const Index in_idx = {static_cast (lroundf(scales[0] * id.at(0))), - static_cast (lroundf(scales[1] * id.at(1))), - static_cast (lroundf(scales[2] * id.at(2))), - static_cast (lroundf(scales[3] * id.at(3)))}; + var(op.getId()) = Fill(op.getOutputShape(0), [&scales, &input](const Index& id) { + Index in_idx; + in_idx.resize(input.getShape().rank()); /* index rank follows the input rank, matching the loop bound below */ + for (int i = 0; i < input.getShape().rank(); i++) { + in_idx.at(i) = static_cast (floorf(id.at(i) / scales[i])); + } return input.at(in_idx); })(); break; diff --git a/contrib/nnc/passes/soft_backend/CPPGenerator.cpp b/contrib/nnc/passes/soft_backend/CPPGenerator.cpp index 2c4b851..9ab20a7 100644 --- a/contrib/nnc/passes/soft_backend/CPPGenerator.cpp +++ b/contrib/nnc/passes/soft_backend/CPPGenerator.cpp @@ -40,6 +40,7 @@ using namespace std; #include "cpp_sqrt.generated.h" #include "cpp_relu.generated.h" #include "cpp_reduce.generated.h" +#include "cpp_resize.generated.h" #include "cpp_softmax.generated.h" #include "cpp_scale.generated.h" #include "cpp_slice.generated.h" @@ -285,6 +286,7 @@ void CPPCodeGenerator::materializeCode(ostream &out, const ModelAnalyzer &ma, co out.write(cpp_conv, sizeof(cpp_conv)); out.write(cpp_depthwise_conv, sizeof(cpp_depthwise_conv)); out.write(cpp_fully_connected, sizeof(cpp_fully_connected)); + out.write(cpp_resize, sizeof(cpp_resize)); out.write(cpp_sigmoid, sizeof(cpp_sigmoid)); out.write(cpp_pool, sizeof(cpp_pool)); out.write(cpp_relu, sizeof(cpp_relu)); diff --git a/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp b/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp index 7f278a0..aedb9f4 
100644 --- a/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp +++ b/contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp @@ -43,6 +43,7 @@ #include "core/modelIR/operations/ReduceFOp.h" #include "core/modelIR/operations/ReluOp.h" #include "core/modelIR/operations/ReshapeOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/ScaleOp.h" #include "core/modelIR/operations/SigmoidOp.h" #include "core/modelIR/operations/SliceOp.h" @@ -250,7 +251,13 @@ void ModelAnalyzer::visit(ops::ReshapeOp& op) { } void ModelAnalyzer::visit(mir::ops::ResizeOp& op) { - assert(false && "Not implemented"); + switch (op.getMode()) { + case mir::ops::ResizeOp::ResizeMethod::nearestNeighbor: + addOpDescr(&op, "resize"); + break; + default: + assert(false && "Not Implemented!"); + } } void ModelAnalyzer::visit(ops::DropoutOp& op) { diff --git a/contrib/nnc/passes/soft_backend/SBSerializer.cpp b/contrib/nnc/passes/soft_backend/SBSerializer.cpp index 7120f66..fb4d371 100644 --- a/contrib/nnc/passes/soft_backend/SBSerializer.cpp +++ b/contrib/nnc/passes/soft_backend/SBSerializer.cpp @@ -53,8 +53,7 @@ #define UNUSED(x) ((void)(x)) -namespace nnc -{ +namespace nnc { static_assert(std::numeric_limits::is_iec559, "Unsupported float type"); @@ -68,26 +67,23 @@ using nnc::mir::TensorVariant; namespace ops = nnc::mir::ops; -namespace -{ - // Currently there are no operations with more then 4 dimensions in kernels/weights etc supported - const auto MAX_DIMS = 4; - const auto MAX_DIM_SIZE = numeric_limits::max(); - // Assuming there are no large enums - const auto MAX_ENUM_VAL = numeric_limits::max(); +namespace { +// Currently no operations with more than 4 dimensions in kernels/weights etc. are supported +const auto MAX_DIMS = 4; +const auto MAX_DIM_SIZE = numeric_limits::max(); +// Assuming there are no large enums +const auto MAX_ENUM_VAL = numeric_limits::max(); } // unnamed namespace -void Serializer::packData(const void *data, size_t size) -{ - const char *p = 
static_cast(data); +void Serializer::packData(const void* data, size_t size) { + const char* p = static_cast(data); size_t old_size = _buffer.size(); _buffer.resize(old_size + size); copy(p, p + size, _buffer.data() + old_size); } template -void Serializer::serializeT(const T &obj) -{ +void Serializer::serializeT(const T& obj) { packData(&obj, sizeof(T)); } @@ -98,25 +94,21 @@ void Serializer::serializeT(const T &obj) * @return Integer value that correspond to enumVal */ template -typename underlying_type::type etoi(E enumVal) -{ +typename underlying_type::type etoi(E enumVal) { return static_cast::type>(enumVal); } -void Serializer::serializeShape(const Shape &s) -{ +void Serializer::serializeShape(const Shape& s) { int32_t rank = s.rank(); assert(rank <= MAX_DIMS); serializeT(s.rank()); - for (int32_t i = 0; i < rank; ++i) - { + for (int32_t i = 0; i < rank; ++i) { int32_t dim = s.dim(i); serializeT(dim); } } -void Serializer::serializeTensor(const TensorVariant &t) -{ +void Serializer::serializeTensor(const TensorVariant& t) { // serialize type assert(etoi(t.getDataType()) < MAX_ENUM_VAL); serializeT(etoi(t.getDataType())); @@ -125,15 +117,14 @@ void Serializer::serializeTensor(const TensorVariant &t) assert(eSize <= MAX_DIMS); serializeT(eSize); // serialize shape - const Shape &shape = t.getShape(); + const Shape& shape = t.getShape(); serializeShape(shape); // serialize actual data size_t tSize = eSize * shape.numElements(); size_t oldSize = _buffer.size(); _buffer.reserve(oldSize + tSize); - for (const Index &idx: ShapeRange(shape)) - { + for (const Index& idx: ShapeRange(shape)) { packData(t.at(idx), eSize); } } @@ -178,7 +169,7 @@ void Serializer::visit(ops::Conv2DOp& op) { void Serializer::visit(ops::DepthwiseConv2DOp& op) { _curOp->_paramStartOffset = _buffer.size(); // serialize kernel - const TensorVariant &kernel = op.getKernel(); + const TensorVariant& kernel = op.getKernel(); serializeTensor(kernel); // serialize strides 
serializeShape(op.getStrides()); @@ -207,8 +198,7 @@ void Serializer::visit(ops::PoolOp& op) { serializePads(op, padsRank); // serialize border type PoolBorderType borderType; - switch (op.getBorderType()) - { + switch (op.getBorderType()) { case ops::PoolOp::BorderType::EMPTY: borderType = PoolBorderType::EMPTY; break; @@ -286,10 +276,8 @@ void Serializer::visit(ops::DropoutOp& op) { serializeT(op.getRate()); } -void Serializer::serialize(list &inferenceSequence) -{ - for (OpDescr &descr: inferenceSequence) - { +void Serializer::serialize(list& inferenceSequence) { + for (OpDescr& descr: inferenceSequence) { _curOp = &descr; descr._op->accept(this); } @@ -356,7 +344,9 @@ void Serializer::visit(mir::ops::SqrtOp& op) { } void Serializer::visit(mir::ops::ResizeOp& op) { - throw PassException("Not implemented yet"); + _curOp->_paramStartOffset = _buffer.size(); + // Result shape is the same as Output shape + serializeShape(op.getOutputShape(0)); } void Serializer::visit(mir::ops::ReduceFOp& op) { diff --git a/contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def index 78c53b2..b3a265c 100644 --- a/contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def +++ b/contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def @@ -435,6 +435,29 @@ void fullConnect(Tensor &out, const char *params, const Tensor &in) out.getData(), shapeToDims(out_s)); } +/** + * @brief Resize assuming tflite axis order (NHWC) + */ +void resize(Tensor& out, const char* params, const Tensor& in) { + // The Tensorflow version of this op allows resize on the width and height + // axis only. 
+ const float* input = in.getData(); + assert(in.getShape().getDims() == 4 && "Should be a 4d tensor"); + RuntimeShape in_shape = shapeToRuntimeShape(in.getShape()); + Shape out_shape = deserializeShape(params); + out.reShape(out_shape); + + assert(out_shape.getDims() == 4 && "Should be a 4d tensor"); + RuntimeShape out_runtime = shapeToRuntimeShape(out_shape); + assert(out_shape[0] == in_shape.Dims(0) && out_shape[3] == in_shape.Dims(3) && + "Resize is unly supported over hight and width"); + + ResizeNearestNeighbor( + in_shape, input, + out_shape[1], out_shape[2], + out_runtime, out.getData()); +} + void cappedRelu(Tensor &out, const char *params, const Tensor &in) { const float *input = in.getData(); diff --git a/contrib/nnc/passes/soft_backend/code_snippets/cpp_resize.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_resize.def new file mode 100644 index 0000000..68dde56 --- /dev/null +++ b/contrib/nnc/passes/soft_backend/code_snippets/cpp_resize.def @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +template +inline void ResizeNearestNeighbor( + const RuntimeShape& unextended_input_shape, const T* input_data, + const int32 output_height, const int32 output_width, + const RuntimeShape& unextended_output_shape, T* output_data) { + // Align corners = true is not supported. 
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + + const RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + const RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + int32 batches = MatchingDim(input_shape, 0, output_shape, 0); + int32 input_height = input_shape.Dims(1); + int32 input_width = input_shape.Dims(2); + int32 depth = MatchingDim(input_shape, 3, output_shape, 3); + + + // We use float to ensure agreement with the Tensorflow implementation. + const float height_scale = static_cast(input_height) / output_height; + const float width_scale = static_cast(input_width) / output_width; + + const int col_offset = input_shape.Dims(3); + const int row_offset = input_shape.Dims(2) * col_offset; + const int batch_offset = input_shape.Dims(1) * row_offset; + + const T* input_ptr = input_data; + T* output_ptr = output_data; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + int32 in_y = std::min(static_cast(std::floor(y * height_scale)), + input_height - 1); + const T* y_input_ptr = input_ptr + in_y * row_offset; + for (int x = 0; x < output_width; ++x) { + int32 in_x = std::min(static_cast(std::floor(x * width_scale)), + input_width - 1); + const T* x_input_ptr = y_input_ptr + in_x * col_offset; + memcpy(output_ptr, x_input_ptr, depth * sizeof(T)); + output_ptr += depth; + } + } + input_ptr += batch_offset; + } +} diff --git a/contrib/nnc/unittests/soft_backend/CPPOperations.cpp b/contrib/nnc/unittests/soft_backend/CPPOperations.cpp index 5be131f..099d774 100644 --- a/contrib/nnc/unittests/soft_backend/CPPOperations.cpp +++ b/contrib/nnc/unittests/soft_backend/CPPOperations.cpp @@ -42,6 +42,7 @@ #include "code_snippets/cpp_pool.def" #include "code_snippets/cpp_reduce.def" #include "code_snippets/cpp_relu.def" +#include "code_snippets/cpp_resize.def" #include 
"code_snippets/cpp_softmax.def" #include "code_snippets/cpp_sqrt.def" #include "code_snippets/cpp_slice.def" @@ -71,6 +72,7 @@ #include "core/modelIR/operations/ReduceFOp.h" #include "core/modelIR/operations/ReluOp.h" #include "core/modelIR/operations/ReshapeOp.h" +#include "core/modelIR/operations/ResizeOp.h" #include "core/modelIR/operations/ScaleOp.h" #include "core/modelIR/operations/SigmoidOp.h" #include "core/modelIR/operations/SliceOp.h" @@ -182,7 +184,7 @@ void fillNTensor(mir::TensorVariant &dst, float start) { */ mir::TensorVariant createNTensor(mir::Shape &shape, float start) { shared_ptr dataBuf( - new char[sizeof(float) * shape.numElements()], default_delete()); + new char[sizeof(float) * shape.numElements()], default_delete()); mir::TensorVariant tensor(shape, dataBuf, mir::DTYPE::FLOAT32, sizeof(float)); fillNTensor(tensor, start); return tensor; @@ -656,6 +658,54 @@ TEST(cpp_operations_test, fully_connected) { createAndRunTestGraph(op_generator, fullConnect, input_ntensors, input_atensor); } +TEST(cpp_operations_test, resize_NN_test) { + mir::Shape test_shapes[] = { + {1, 8, 8, 1}, + {2, 10, 10, 1}, + {1, 11, 11, 2}, + {2, 8, 12, 2}, + {1, 48, 12, 1}, + {1, 48, 48, 1}, + {1, 48, 56, 1} + }; + for (mir::Shape res_shape: test_shapes) { + vector input_shape_data{res_shape.dim(0), 4, 4, res_shape.dim(3)}; + vector> input_ntensors(1); + Tensor input_atensor; + fillTensors(input_ntensors[0], input_atensor, input_shape_data, 1.0f); + auto op_generator = [res_shape](mir::Graph& g, const std::vector& inputs) { + return g.create( + "y", inputs[0], + mir::ops::ResizeOp::ResizeMethod::nearestNeighbor, res_shape); + }; + + createAndRunTestGraph(op_generator, resize, input_ntensors, input_atensor); + } +} + +TEST(cpp_operations_test, resize_NN_test_scales) { + cout << "\n"; + std::vector test_scales[] = { + {1, 2, 2, 1}, + {1, 2, 3, 1}, + {1, 3, 2, 1}, + {1, 2.5, 2, 1}, + {1, 3, 9, 1} + }; + for (const std::vector& scales: test_scales) { + vector 
input_shape_data{1, 4, 4, 1}; + vector> input_ntensors(1); + Tensor input_atensor; + fillTensors(input_ntensors[0], input_atensor, input_shape_data, 1.0f); + auto op_generator = [scales](mir::Graph& g, const std::vector& inputs) { + return g.create( + "y", inputs[0], + mir::ops::ResizeOp::ResizeMethod::nearestNeighbor, scales); + }; + createAndRunTestGraph(op_generator, resize, input_ntensors, input_atensor); + } +} + template static mir::Operation* createPool(mir::Graph& g, const std::vector& inputs, -- 2.7.4