From 56f2c6fd007c5880d96b45a58d7c14476e8ae0f2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Prasanna=20R/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 4 Oct 2018 12:28:20 +0530 Subject: [PATCH] Introduce EQUAL_EX operation on runtime (#2872) This patch introduces EQUAL_EX operation on runtime. The related ACL implementation is done in subsequent patches. Signed-off-by: prasannar --- runtimes/pure_arm_compute/src/compilation.cc | 8 +++ runtimes/pure_arm_compute/src/internal/op/Equal.cc | 65 +++++++++++++++++ runtimes/pure_arm_compute/src/internal/op/Equal.h | 83 ++++++++++++++++++++++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + runtimes/pure_arm_compute/src/model.cc | 12 ++++ 5 files changed, 170 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/Equal.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/Equal.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 051681e..7b37fc8 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -531,6 +531,7 @@ public: void visit(const ::internal::tflite::op::Neg::Node &node) override; void visit(const ::internal::tflite::op::Exp::Node &node) override; void visit(const ::internal::tflite::op::ReduceSum::Node &node) override; + void visit(const ::internal::tflite::op::Equal::Node &node) override; private: const ::internal::tflite::operand::Set &_ctx; @@ -3364,6 +3365,13 @@ void Planner::visit(const ::internal::tflite::op::RSQRT::Node &node) throw std::runtime_error("Not supported, yet"); } +void Planner::visit(const ::internal::tflite::op::Equal::Node &node) +{ + VERBOSE(Equal) << "Configure Equal operation" << std::endl; + + throw std::runtime_error("Not supported, yet"); +} + void Planner::visit(const ::internal::tflite::op::SquaredDifference::Node &node) { const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index}; diff --git a/runtimes/pure_arm_compute/src/internal/op/Equal.cc b/runtimes/pure_arm_compute/src/internal/op/Equal.cc new file mode 100644 index 0000000..99e3a56 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Equal.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "internal/op/Equal.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Equal +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace Equal +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Equal +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 2 && outputCount == 1); + + output_index = outputs[0]; + + // Each input should be interpreted as follows: + // + // 0 -> Input Tensor Index + // 1 -> Block size Index + input1_index = inputs[0]; + input2_index = inputs[1]; +} + +} // namespace Equal +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/Equal.h b/runtimes/pure_arm_compute/src/internal/op/Equal.h new file mode 100644 index 0000000..78b9f84 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Equal.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __INTERNAL_OP_EQUAL_H__ +#define __INTERNAL_OP_EQUAL_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Equal +{ + +struct Param +{ + int32_t output_index; + + int32_t input1_index; + int32_t input2_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +} // namespace Equal +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Equal +{ +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace Equal +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_EQUAL_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 92c9f82..89960f3 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -62,6 +62,7 @@ #include "internal/op/Neg.h" #include "internal/op/Exp.h" #include "internal/op/ReduceSum.h" +#include "internal/op/Equal.h" namespace internal { @@ -124,6 +125,7 @@ struct NodeVisitor virtual void visit(const Neg::Node &) = 0; virtual void visit(const Exp::Node &) = 0; virtual void visit(const ReduceSum::Node &) = 0; + virtual void visit(const Equal::Node &) = 0; }; } // namespace op diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 69febaf..9a66251 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -764,6 +764,18 @@ int ANeuralNetworksModel_addOperationEx(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_EQUAL_EX: + { + using internal::tflite::op::Equal::Param; + using internal::tflite::op::Equal::Node; + + // Add 'operations' + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } case ANEURALNETWORKS_SQUARED_DIFFERENCE_EX: { using internal::tflite::op::SquaredDifference::Param; -- 2.7.4