From 149e2230f394b30d38c00967e00632f01c99263e Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=B5=9C=EC=84=B1=EC=A7=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Principal=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 18 May 2018 17:05:20 +0900 Subject: [PATCH] [Gather OP] add gather operation fall-back (#1267) * [Gather OP] add gather operation fall-back This commit introduces GATHER operation which is the first version. - currently, only for float32 and 1D Signed-off-by: SungJin Choi * [Gather Op] Gather operation to CPU executor This commit adds Gather operation in CpuExecutor. -currently, only for float32 and 1D Signed-off-by: SungJin Choi --- runtimes/nn/common/CMakeLists.txt | 1 + runtimes/nn/common/CpuExecutor.cpp | 16 ++++++++- runtimes/nn/common/OperationsUtils.cpp | 33 ++++++++++++++++++ runtimes/nn/common/include/Operations.h | 3 ++ runtimes/nn/common/include/OperationsUtils.h | 2 ++ runtimes/nn/common/operations/Gather.cpp | 52 ++++++++++++++++++++++++++++ 6 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 runtimes/nn/common/operations/Gather.cpp diff --git a/runtimes/nn/common/CMakeLists.txt b/runtimes/nn/common/CMakeLists.txt index 31d2d80..83562f2 100644 --- a/runtimes/nn/common/CMakeLists.txt +++ b/runtimes/nn/common/CMakeLists.txt @@ -23,6 +23,7 @@ SET (CUR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/operations/Reshape.cpp ${CMAKE_CURRENT_SOURCE_DIR}/Logging.cpp ${CMAKE_CURRENT_SOURCE_DIR}/operations/DepthwiseConv2D.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/operations/Gather.cpp ) SET (SRCS ${SRCS} diff --git a/runtimes/nn/common/CpuExecutor.cpp b/runtimes/nn/common/CpuExecutor.cpp index 36a4b21..c494978 100644 --- a/runtimes/nn/common/CpuExecutor.cpp +++ b/runtimes/nn/common/CpuExecutor.cpp @@ -1308,7 +1308,21 @@ int CpuExecutor::executeOperation(const Operation &operation) break; case OperationType::GATHER: { - // TODO-NNRT : implement this operation. + if (!allParametersPresent(2, 1)) + { + return ANEURALNETWORKS_BAD_DATA; + } + const RunTimeOperandInfo &input = mOperands[ins[0]]; + const RunTimeOperandInfo &coords = mOperands[ins[1]]; + + RunTimeOperandInfo &output = mOperands[outs[0]]; + Shape outShape = output.shape(); + + success = gatherPrepare(input.shape(), coords.shape(), &outShape) && + setInfoAndAllocateIfNeeded(&output, outShape) && + gatherGeneric(reinterpret_cast(input.buffer), input.shape(), + reinterpret_cast(coords.buffer), coords.shape(), + reinterpret_cast(output.buffer), outShape); } break; case OperationType::TOPK_V2: diff --git a/runtimes/nn/common/OperationsUtils.cpp b/runtimes/nn/common/OperationsUtils.cpp index 8a49100..9751502 100644 --- a/runtimes/nn/common/OperationsUtils.cpp +++ b/runtimes/nn/common/OperationsUtils.cpp @@ -586,5 +586,38 @@ bool hashtableLookupPrepare(const Shape &lookupShape, const Shape &keyShape, return true; } +bool gatherPrepare(const Shape &inputShape, const Shape &coordsShape, Shape *outputShape) +{ + // Only INT32 positions are supported. + NN_OPS_CHECK(coordsShape.type == OperandType::TENSOR_INT32); + // Check that input and output types match. + NN_OPS_CHECK(inputShape.type == outputShape->type); + // TODO: Currently, only 0D or 1D coordsShape are currently supported. Other dimensions are needed + NN_OPS_CHECK(getNumberOfDimensions(coordsShape) <= 1); + + // TODO: other dimension + switch (inputShape.type) + { + case OperandType::TENSOR_FLOAT32: + case OperandType::TENSOR_INT32: + { + // Fully supported by reference_ops::Gather. + } + break; + default: + std::cerr << "Only float32 and string types are supported" << std::endl; + return false; + } + + // calculate dimension (axis is supposed to be 0) + const int num_dimensions = + (getNumberOfDimensions(inputShape) - 1) + getNumberOfDimensions(coordsShape); + NN_OPS_CHECK(num_dimensions >= 0); + + outputShape->type = inputShape.type; + + return true; +} + } // namespace rt } // namespace nnfw diff --git a/runtimes/nn/common/include/Operations.h b/runtimes/nn/common/include/Operations.h index 585a689..29ff6e6 100644 --- a/runtimes/nn/common/include/Operations.h +++ b/runtimes/nn/common/include/Operations.h @@ -161,6 +161,9 @@ bool depthToSpaceGeneric(const uint8_t *inputData, const Shape &inputShape, int3 bool spaceToDepthGeneric(const uint8_t *inputData, const Shape &inputShape, int32_t blockSize, uint8_t *outputData, const Shape &outputShape); +bool gatherGeneric(const uint8_t *inputData, const Shape &inputShape, const int32_t *coordsData, + const Shape &coordsShape, uint8_t *outputData, const Shape &outputShape); + } // namespace rt } // namespace nnfw diff --git a/runtimes/nn/common/include/OperationsUtils.h b/runtimes/nn/common/include/OperationsUtils.h index 4b82fcf..542b26a 100644 --- a/runtimes/nn/common/include/OperationsUtils.h +++ b/runtimes/nn/common/include/OperationsUtils.h @@ -186,6 +186,8 @@ bool embeddingLookupPrepare(const Shape &valueShape, const Shape &lookupShape, S bool hashtableLookupPrepare(const Shape &lookupShape, const Shape &keyShape, const Shape &valueShape, Shape *outputShape, Shape *hitShape); +bool gatherPrepare(const Shape &input, const Shape &coords, Shape *output); + #define ANDROID_NN_MACRO_DISPATCH_INTERNAL(macro) \ case (int32_t)FusedActivationFunc::NONE: \ macro(kNone); \ diff --git a/runtimes/nn/common/operations/Gather.cpp b/runtimes/nn/common/operations/Gather.cpp new file mode 100644 index 0000000..a9175b7 --- /dev/null +++ b/runtimes/nn/common/operations/Gather.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Contains the implementation of the operations. + +#define LOG_TAG "Operations" + +#include "Operations.h" +#include "OperationsUtils.h" + +#include "internal/optimized/optimized_ops.h" + +namespace nnfw +{ +namespace rt +{ + +bool gatherGeneric(const uint8_t *inputData, const Shape &inputShape, const int32_t *coordsData, + const Shape &coordsShape, uint8_t *outputData, const Shape &outputShape) +{ + // TODO: other types + if (inputShape.type == OperandType::TENSOR_FLOAT32) + { + optimized_ops::Gather( + reinterpret_cast(inputData), convertShapeToDims(inputShape), + reinterpret_cast(coordsData), convertShapeToDims(coordsShape), + reinterpret_cast(outputData), convertShapeToDims(outputShape)); + } + else + { + LOG(ERROR) << "Unsupported data type"; + return false; + } + return true; +} + +} // namespace rt +} // namespace nnfw -- 2.7.4