From: PRAVEEN DORESWAMY NAIDU/System SW /SRI-Bangalore/Staff Engineer/삼성전자 Date: Thu, 25 Oct 2018 10:51:14 +0000 (+0530) Subject: Introduce ArgMax operation. (#3288) X-Git-Tag: 0.3~519 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=37609e5e8c5d1a7707f84bd8da38c84948ae5fe6;p=platform%2Fcore%2Fml%2Fnnfw.git Introduce ArgMax operation. (#3288) This patch will introduce the ArgMax operation in NNFW. Signed-off-by: Praveen D N --- diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 9071c8b..3b927bc 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -537,6 +537,7 @@ public: void visit(const ::internal::tflite::op::LSTM::Node &node) override; void visit(const ::internal::tflite::op::Floor::Node &node) override; void visit(const ::internal::tflite::op::Split::Node &node) override; + void visit(const ::internal::tflite::op::ArgMax::Node &node) override; void visit(const ::internal::tflite::op::RSQRT::Node &node) override; void visit(const ::internal::tflite::op::Pad::Node &node) override; void visit(const ::internal::tflite::op::SpaceToDepth::Node &node) override; @@ -3497,6 +3498,13 @@ void Planner::visit(const ::internal::tflite::op::Floor::Node &node) _builder.addStage(stage); } +void Planner::visit(const ::internal::tflite::op::ArgMax::Node &node) +{ + VERBOSE(ArgMax) << "Configure ARGMAX operation" << std::endl; + + throw std::runtime_error("Not supported, yet"); +} + void Planner::visit(const ::internal::tflite::op::RSQRT::Node &node) { VERBOSE(RSQRT) << "Configure Rsqrt operation" << std::endl; diff --git a/runtimes/pure_arm_compute/src/internal/op/ArgMax.cc b/runtimes/pure_arm_compute/src/internal/op/ArgMax.cc new file mode 100644 index 0000000..4854303 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/ArgMax.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "internal/op/ArgMax.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace ArgMax +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace ArgMax +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace ArgMax +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 2 && outputCount == 1); + + ofm_index = outputs[0]; + + // Each input should be interpreted as follows: + // + // 0 -> Input Tensor Index + // 1 -> Axis Tensor Index + ifm_index = inputs[0]; + axis_index = inputs[1]; +} + +} // namespace ArgMax +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/ArgMax.h b/runtimes/pure_arm_compute/src/internal/op/ArgMax.h new file mode 100644 index 0000000..780af22 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/ArgMax.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __INTERNAL_OP_ARGMAX_H__ +#define __INTERNAL_OP_ARGMAX_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace ArgMax +{ + +struct Param +{ + int32_t ofm_index; + + int32_t ifm_index; + int32_t axis_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace ArgMax +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_ARGMAX_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index d97bcad..8dc5269 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -32,6 +32,7 @@ #include "internal/op/Dequantize.h" #include "internal/op/MaxPool2D.h" #include "internal/op/AvgPool2D.h" +#include "internal/op/ArgMax.h" #include "internal/op/Concat.h" #include "internal/op/Reshape.h" #include "internal/op/ResizeBilinear.h" @@ -176,6 +177,12 @@ struct NodeVisitor */ virtual void visit(const Concat::Node &) = 0; /** + * @brief Visit a ArgMax node + * @param[in] node ArgMax node to visit + * @return N/A + */ + virtual void visit(const ArgMax::Node &) = 0; + /** * @brief Visit an Reshape node * @param[in] node Reshape node to visit * @return N/A diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 2853eaf..6a33124 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -905,6 +905,19 @@ int ANeuralNetworksModel_addOperationEx(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_ARGMAX_EX: + { + using internal::tflite::op::ArgMax::Param; + using internal::tflite::op::ArgMax::Node; + + // Add 'operations' + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } + default: throw std::runtime_error{"Not supported operation"}; }