From 17a5748f453a8b549616b6e28cd2f05b33af89b1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Prasanna=20R/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 31 Oct 2018 13:38:18 +0530 Subject: [PATCH] Introduce SQRT in PACL (#3412) This patch introduces SQRT op in PACL Related issue: #3341 Signed-off-by: prasannar --- runtimes/pure_arm_compute/src/compilation.cc | 8 ++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 7 ++ runtimes/pure_arm_compute/src/internal/op/SQRT.cc | 62 ++++++++++++ runtimes/pure_arm_compute/src/internal/op/SQRT.h | 105 +++++++++++++++++++++ runtimes/pure_arm_compute/src/model.cc | 12 +++ 5 files changed, 194 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/SQRT.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/SQRT.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 5727846..5b627ab 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -539,6 +539,7 @@ public: void visit(const ::internal::tflite::op::Split::Node &node) override; void visit(const ::internal::tflite::op::ArgMax::Node &node) override; void visit(const ::internal::tflite::op::RSQRT::Node &node) override; + void visit(const ::internal::tflite::op::SQRT::Node &node) override; void visit(const ::internal::tflite::op::Pad::Node &node) override; void visit(const ::internal::tflite::op::SpaceToDepth::Node &node) override; void visit(const ::internal::tflite::op::SpaceToBatchND::Node &node) override; @@ -3504,6 +3505,13 @@ void Planner::visit(const ::internal::tflite::op::ArgMax::Node &node) throw std::runtime_error("Not supported, yet"); } +void Planner::visit(const ::internal::tflite::op::SQRT::Node &node) +{ + VERBOSE(SQRT) << "Configure SQRT operation" << std::endl; + + throw std::runtime_error("Not supported, yet"); +} + void Planner::visit(const ::internal::tflite::op::RSQRT::Node &node) { VERBOSE(RSQRT) << "Configure Rsqrt operation" << std::endl; diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 5b56e7c..6f8b234 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -56,6 +56,7 @@ #include "internal/op/Floor.h" #include "internal/op/Split.h" #include "internal/op/RSQRT.h" +#include "internal/op/SQRT.h" #include "internal/op/Pad.h" #include "internal/op/SpaceToDepth.h" #include "internal/op/SpaceToBatchND.h" @@ -316,6 +317,12 @@ struct NodeVisitor */ virtual void visit(const RSQRT::Node &) = 0; /** + * @brief Visit an SQRT node + * @param[in] node SQRT node to visit + * @return N/A + */ + virtual void visit(const SQRT::Node &) = 0; + /** * @brief Visit a Pad node * @param[in] node Pad node to visit * @return N/A diff --git a/runtimes/pure_arm_compute/src/internal/op/SQRT.cc b/runtimes/pure_arm_compute/src/internal/op/SQRT.cc new file mode 100644 index 0000000..70ce42e --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/SQRT.cc @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "internal/op/SQRT.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace SQRT +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace SQRT +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace SQRT +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 1 && outputCount == 1); + + output_index = outputs[0]; + + // Each input should be interpreted as follows: + // 0 -> input Tensor Index + input_index = inputs[0]; +} + +} // namespace SQRT +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/SQRT.h b/runtimes/pure_arm_compute/src/internal/op/SQRT.h new file mode 100644 index 0000000..85dfb97 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/SQRT.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file SQRT.h + * @ingroup COM_AI_RUNTIME + * @brief This file defines internal::tflite::op::SQRT::Param struct + * and internal::tflite::op::SQRT::Node class + */ +#ifndef __INTERNAL_OP_SQRT_H__ +#define __INTERNAL_OP_SQRT_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace SQRT +{ + +/** + * @brief Struct to have indexes for operation parameter + */ +struct Param +{ + int32_t output_index; /**< Index of output feature map */ + + int32_t input_index; /**< Index of input feature map */ + /** + * @brief Construct as default + */ + Param() = default; + /** + * @brief Construct a new Param object with params + * @param[in] inputCount Count of inputs + * @param[in] inputs Pointer of inputs + * @param[in] outputCount Count of outputs + * @param[in] outputs Pointer of outputs + */ + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +/** + * @brief Class to represent an operation of data structure + */ +class Node final : public op::Node +{ +public: + /** + * @brief Construct a new Node object with param + * @param[in] param Param object that makes up a Node + */ + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + /** + * @brief Destruct as default + */ + virtual ~Node() = default; + +public: + /** + * @brief Get a reference of Param object + * @return Reference of Param object + */ + const Param ¶m(void) const { return _param; } + +public: + /** + * @brief Visit this Node by NodeVisitor + * @return N/A + */ + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace SQRT +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_SQRT_H__ diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index bfa3086..74368bb 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -786,6 +786,18 @@ int ANeuralNetworksModel_addOperationEx(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_SQRT_EX: + { + using internal::tflite::op::SQRT::Param; + using internal::tflite::op::SQRT::Node; + + // Add 'operations' + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } case ANEURALNETWORKS_EQUAL_EX: { using internal::tflite::op::Equal::Param; -- 2.7.4