From a713c6e9dbce1cc7f197a882e18a4f2c162e9aa5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Devansh=20Bansal/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 5 Oct 2018 11:35:26 +0530 Subject: [PATCH] Introduce TRANSPOSE_CONV operation on runtime (#2941) This patch introduces TRANSPOSE_CONV operation on runtime. This operator is required by grocery recognition system model. Signed-off-by: b.devansh --- runtimes/pure_arm_compute/src/compilation.cc | 8 +++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + .../src/internal/op/TransposeConv.cc | 67 ++++++++++++++++++++ .../src/internal/op/TransposeConv.h | 71 ++++++++++++++++++++++ runtimes/pure_arm_compute/src/model.cc | 11 ++++ 5 files changed, 159 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/TransposeConv.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/TransposeConv.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 27c7102..c3996ef 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -532,6 +532,7 @@ public: void visit(const ::internal::tflite::op::Exp::Node &node) override; void visit(const ::internal::tflite::op::ReduceSum::Node &node) override; void visit(const ::internal::tflite::op::Equal::Node &node) override; + void visit(const ::internal::tflite::op::TransposeConv::Node &node) override; private: const ::internal::tflite::operand::Set &_ctx; @@ -3372,6 +3373,13 @@ void Planner::visit(const ::internal::tflite::op::Equal::Node &node) throw std::runtime_error("Not supported, yet"); } +void Planner::visit(const ::internal::tflite::op::TransposeConv::Node &node) +{ + VERBOSE(TransposeConv) << "Configure TransposeConv operation" << std::endl; + + throw std::runtime_error("Not supported, yet"); +} + void Planner::visit(const ::internal::tflite::op::SquaredDifference::Node &node) { const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index}; diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 89960f3..c7464cd 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -63,6 +63,7 @@ #include "internal/op/Exp.h" #include "internal/op/ReduceSum.h" #include "internal/op/Equal.h" +#include "internal/op/TransposeConv.h" namespace internal { @@ -126,6 +127,7 @@ struct NodeVisitor virtual void visit(const Exp::Node &) = 0; virtual void visit(const ReduceSum::Node &) = 0; virtual void visit(const Equal::Node &) = 0; + virtual void visit(const TransposeConv::Node &) = 0; }; } // namespace op diff --git a/runtimes/pure_arm_compute/src/internal/op/TransposeConv.cc b/runtimes/pure_arm_compute/src/internal/op/TransposeConv.cc new file mode 100644 index 0000000..7cbff4f --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/TransposeConv.cc @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "internal/op/TransposeConv.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace TransposeConv +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace TransposeConv +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace TransposeConv +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 3 && outputCount == 1); + + ofm_index = outputs[0]; + + // Each input should be interpreted as follows: + // + // 0 -> Output Shape Index + // 1 -> Weights Index + // 2 -> Input Tensor Index + op_shape_index = inputs[0]; + weights_index = inputs[1]; + ifm_index = inputs[2]; +} + +} // namespace TransposeConv +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/TransposeConv.h b/runtimes/pure_arm_compute/src/internal/op/TransposeConv.h new file mode 100644 index 0000000..4cc3088 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/TransposeConv.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __INTERNAL_OP_TRANSPOSECONV_H__ +#define __INTERNAL_OP_TRANSPOSECONV_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace TransposeConv +{ + +struct Param +{ + int32_t ofm_index; + + int32_t ifm_index; + int32_t op_shape_index; + int32_t weights_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace TransposeConv +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_TRANSPOSECONV_H__ diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 9a66251..d37fdff 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -709,6 +709,17 @@ int ANeuralNetworksModel_addOperation(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_TRANSPOSE_CONV_EX: + { + using internal::tflite::op::TransposeConv::Param; + using internal::tflite::op::TransposeConv::Node; + + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } default: throw std::runtime_error{"Not supported operation"}; }; -- 2.7.4