From 6e612909f703eccb2017655052bca85da1c936ec Mon Sep 17 00:00:00 2001 From: =?utf8?q?Vishal=20Keshav/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 31 Oct 2018 06:31:43 +0530 Subject: [PATCH] Adding Unpack Node (#3216) Added unpack node in pure_arm_compute Signed-off-by: Vishal Keshav --- runtimes/pure_arm_compute/src/compilation.cc | 8 +++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + .../pure_arm_compute/src/internal/op/Unpack.cc | 66 ++++++++++++++++++++ runtimes/pure_arm_compute/src/internal/op/Unpack.h | 71 ++++++++++++++++++++++ runtimes/pure_arm_compute/src/model.cc | 11 ++++ 5 files changed, 158 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/Unpack.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/Unpack.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 196451c..14e9452 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -551,6 +551,7 @@ public: void visit(const ::internal::tflite::op::SquaredDifference::Node &node) override; void visit(const ::internal::tflite::op::LocalResponseNormalization::Node &node) override; void visit(const ::internal::tflite::op::DepthToSpace::Node &node) override; + void visit(const ::internal::tflite::op::Unpack::Node &node) override; void visit(const ::internal::tflite::op::Neg::Node &node) override; void visit(const ::internal::tflite::op::Exp::Node &node) override; void visit(const ::internal::tflite::op::ReduceSum::Node &node) override; @@ -4610,6 +4611,13 @@ void Planner::visit(const ::internal::tflite::op::DepthToSpace::Node &node) _builder.addStage(stage); } +void Planner::visit(const ::internal::tflite::op::Unpack::Node &node) +{ + VERBOSE(Unpack) << "Configure Unpack operation" << std::endl; + + throw std::runtime_error("Not supported, yet"); +} + void Planner::visit(const ::internal::tflite::op::Pack::Node &node) { VERBOSE(Pack) << "Configure Pack operation" << std::endl; diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 8dc5269..5b56e7c 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -66,6 +66,7 @@ #include "internal/op/SquaredDifference.h" #include "internal/op/LocalResponseNormalization.h" #include "internal/op/DepthToSpace.h" +#include "internal/op/Unpack.h" #include "internal/op/Neg.h" #include "internal/op/Exp.h" #include "internal/op/ReduceSum.h" @@ -365,6 +366,7 @@ struct NodeVisitor virtual void visit(const SquaredDifference::Node &) = 0; virtual void visit(const LocalResponseNormalization::Node &) = 0; virtual void visit(const DepthToSpace::Node &) = 0; + virtual void visit(const Unpack::Node &) = 0; virtual void visit(const Neg::Node &) = 0; virtual void visit(const Exp::Node &) = 0; virtual void visit(const ReduceSum::Node &) = 0; diff --git a/runtimes/pure_arm_compute/src/internal/op/Unpack.cc b/runtimes/pure_arm_compute/src/internal/op/Unpack.cc new file mode 100644 index 0000000..fedc22c --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Unpack.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "internal/op/Unpack.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Unpack +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace Unpack +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Unpack +{ +// There are two inputs: tensor which is to be unpacked +// and axis along which tensor needs to be unpacked. + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 2); + + ifm_index = inputs[0]; + + for (uint32_t n = 0; n < outputCount; ++n) + { + ofm_indexes.emplace_back(outputs[n]); + } + axis_index = inputs[1]; +} + +} // namespace Unpack +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/Unpack.h b/runtimes/pure_arm_compute/src/internal/op/Unpack.h new file mode 100644 index 0000000..85592e3 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Unpack.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __INTERNAL_OP_UNPACK_H__ +#define __INTERNAL_OP_UNPACK_H__ + +#include "internal/op/Node.h" + +#include +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Unpack +{ + +struct Param +{ + int32_t ifm_index; + int32_t axis_index; + // There are N outputs after Unpacking Input Tensor along axis + std::vector ofm_indexes; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace Unpack +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_UNPACK_H__ diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 6a33124..bfa3086 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -846,6 +846,17 @@ int ANeuralNetworksModel_addOperationEx(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_UNPACK_EX: + { + using internal::tflite::op::Unpack::Param; + using internal::tflite::op::Unpack::Node; + + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } case ANEURALNETWORKS_NEG_EX: { using internal::tflite::op::Neg::Param; -- 2.7.4