[neurun] Merge Equal and NotEqual to Comparison op (#4797)
author김수진/On-Device Lab(SR)/Engineer/삼성전자 <sjsujin.kim@samsung.com>
Wed, 20 Mar 2019 08:37:04 +0000 (17:37 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 20 Mar 2019 08:37:04 +0000 (17:37 +0900)
In https://github.sec.samsung.net/STAR/nnfw/pull/4738#discussion_r156180, we've discussed to group all operations in `::arm_compute::ComparisonOperation`.
This commit introduces `ComparisonNode` that includes operations in `::::arm_compute::ComparisonOperation` and intergrate `Equal` and `NotEqual` to `ComparisonNode`

`ComparisonNode` can be support below.

```
enum class ComparisonType
 {
   Equal,
   NotEqual,
   Greater,
   GreaterEqual,
   Less,
   LessEqual
 };
```

Signed-off-by: sjsujinkim <sjsujin.kim@samsung.com>
runtimes/neurun/backend/acl_cl/StageGenerator.cc
runtimes/neurun/backend/acl_cl/StageGenerator.h
runtimes/neurun/core/include/model/operation/ComparisonNode.h [new file with mode: 0644]
runtimes/neurun/core/include/model/operation/EqualNode.h [deleted file]
runtimes/neurun/core/include/model/operation/Node.Include.h
runtimes/neurun/core/include/model/operation/NotEqualNode.h [deleted file]
runtimes/neurun/core/include/model/operation/Op.lst
runtimes/neurun/core/src/model/operation/ComparisonNode.cc [new file with mode: 0644]
runtimes/neurun/core/src/model/operation/EqualNode.cc [deleted file]
runtimes/neurun/core/src/model/operation/NotEqualNode.cc [deleted file]
runtimes/neurun/frontend/nnapi/wrapper/OperationFactory.cc

index 73530aa..4ee36df 100644 (file)
@@ -1610,11 +1610,11 @@ void StageGenerator::visit(const model::operation::ReduceMaxNode &node)
   });
 }
 
-void StageGenerator::visit(const model::operation::NotEqualNode &node)
+void StageGenerator::visit(const model::operation::ComparisonNode &node)
 {
   const auto output_index{node.getOutputs().at(0)};
-  const auto input0_index{node.getInputs().at(model::operation::NotEqualNode::Input::INPUT0)};
-  const auto input1_index{node.getInputs().at(model::operation::NotEqualNode::Input::INPUT1)};
+  const auto input0_index{node.getInputs().at(model::operation::ComparisonNode::Input::INPUT0)};
+  const auto input1_index{node.getInputs().at(model::operation::ComparisonNode::Input::INPUT1)};
 
   if (!(_ctx.at(input0_index).shape() == _ctx.at(input1_index).shape()))
   {
@@ -1635,6 +1635,8 @@ void StageGenerator::visit(const model::operation::NotEqualNode &node)
     model::operand::Index output_index;
     model::operand::Index input0_index;
     model::operand::Index input1_index;
+
+    model::operation::ComparisonNode::ComparisonType comparison_type;
   };
 
   Param param;
@@ -1643,6 +1645,8 @@ void StageGenerator::visit(const model::operation::NotEqualNode &node)
   param.input0_index = input0_index;
   param.input1_index = input1_index;
 
+  param.comparison_type = node.param().comparison_type;
+
   auto tensors = _tensor_builder;
 
   returnStage([tensors, param](IExecutionBuilder &builder) {
@@ -1655,7 +1659,7 @@ void StageGenerator::visit(const model::operation::NotEqualNode &node)
     auto l = make_layer<::arm_compute::CLComparison>();
 
     l->configure(input0_alloc->handle(), input1_alloc->handle(), output_alloc->handle(),
-                 arm_compute::ComparisonOperation::NotEqual);
+                 (arm_compute::ComparisonOperation)param.comparison_type);
 
     fn = std::move(l);
 
@@ -2458,58 +2462,6 @@ void StageGenerator::visit(const model::operation::LogicalNotNode &node)
   });
 }
 
-void StageGenerator::visit(const model::operation::EqualNode &node)
-{
-  const auto output_index{node.getOutputs().at(0)};
-  const auto input0_index{node.getInputs().at(model::operation::EqualNode::Input::INPUT0)};
-  const auto input1_index{node.getInputs().at(model::operation::EqualNode::Input::INPUT1)};
-
-  if (!(_ctx.at(input0_index).shape() == _ctx.at(input1_index).shape()))
-  {
-    const auto broadcast_rank =
-        std::max(_ctx.at(input0_index).shape().rank(), _ctx.at(input1_index).shape().rank());
-    const_cast<neurun::model::operand::Shape &>(_ctx.at(input0_index).shape())
-        .extendRank(broadcast_rank);
-    const_cast<neurun::model::operand::Shape &>(_ctx.at(input1_index).shape())
-        .extendRank(broadcast_rank);
-  }
-
-  // Construct operation parameters
-  struct Param
-  {
-    model::operand::Index output_index;
-    model::operand::Index input0_index;
-    model::operand::Index input1_index;
-  };
-
-  Param param;
-
-  param.output_index = output_index;
-  param.input0_index = input0_index;
-  param.input1_index = input1_index;
-
-  auto tensors = _tensor_builder;
-
-  returnStage([tensors, param](IExecutionBuilder &builder) {
-    auto output_alloc = tensors->at(param.output_index).get();
-    auto input0_alloc = tensors->at(param.input0_index).get();
-    auto input1_alloc = tensors->at(param.input1_index).get();
-
-    std::unique_ptr<::arm_compute::IFunction> fn;
-
-    auto l = make_layer<::arm_compute::CLComparison>();
-
-    l->configure(input0_alloc->handle(), input1_alloc->handle(), output_alloc->handle(),
-                 ::arm_compute::ComparisonOperation::Equal);
-
-    fn = std::move(l);
-
-    auto acl_fn = make_cl_function(std::move(fn));
-
-    builder.append(std::move(acl_fn));
-  });
-}
-
 void StageGenerator::visit(const model::operation::SquaredDifferenceNode &node)
 {
   const auto ofm_index{node.getOutputs().at(0)};
index 5dad5c7..e16699b 100644 (file)
@@ -58,7 +58,7 @@ public:
   virtual void visit(const model::operation::ExpNode &) override;
   virtual void visit(const model::operation::LogisticNode &) override;
   virtual void visit(const model::operation::ReduceMaxNode &) override;
-  virtual void visit(const model::operation::NotEqualNode &) override;
+  virtual void visit(const model::operation::ComparisonNode &) override;
   virtual void visit(const model::operation::LogicalAndNode &) override;
   virtual void visit(const model::operation::RSQRTNode &) override;
   virtual void visit(const model::operation::ReLUNode &) override;
@@ -76,7 +76,6 @@ public:
   virtual void visit(const model::operation::SQRTNode &) override;
   virtual void visit(const model::operation::LogicalOrNode &) override;
   virtual void visit(const model::operation::LogicalNotNode &) override;
-  virtual void visit(const model::operation::EqualNode &) override;
   virtual void visit(const model::operation::SquaredDifferenceNode &) override;
   virtual void visit(const model::operation::TopKV2Node &) override;
 
diff --git a/runtimes/neurun/core/include/model/operation/ComparisonNode.h b/runtimes/neurun/core/include/model/operation/ComparisonNode.h
new file mode 100644 (file)
index 0000000..8778b6c
--- /dev/null
@@ -0,0 +1,72 @@
+/*
+  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  *      http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+#ifndef __NEURUN_MODEL_OPERATION_COMPARISON_NODE_H__
+#define __NEURUN_MODEL_OPERATION_COMPARISON_NODE_H__
+
+#include "model/operation/Node.h"
+
+namespace neurun
+{
+namespace model
+{
+namespace operation
+{
+
+class ComparisonNode : public model::operation::Node
+{
+public:
+  enum Input
+  {
+    INPUT0 = 0,
+    INPUT1
+  };
+
+  enum class ComparisonType
+  {
+    Equal,
+    NotEqual,
+    Greater,
+    GreaterEqual,
+    Less,
+    LessEqual
+  };
+
+  struct Param
+  {
+    ComparisonType comparison_type;
+  };
+
+public:
+  ComparisonNode(const operand::IndexSet &inputs, const operand::IndexSet &outputs,
+                 const Param &param);
+
+public:
+  virtual void accept(NodeVisitor &&) const override;
+  virtual std::string getName() const override { return "Comparison"; }
+
+public:
+  const Param &param() const { return _param; }
+
+private:
+  Param _param;
+};
+
+} // namespace operation
+} // namespace model
+} // namespace neurun
+
+#endif // __NEURUN_MODEL_OPERATION_COMPARISON_NODE_H__
diff --git a/runtimes/neurun/core/include/model/operation/EqualNode.h b/runtimes/neurun/core/include/model/operation/EqualNode.h
deleted file mode 100644 (file)
index a1792f1..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __NEURUN_MODEL_OPERATION_EQUAL_NODE_H__
-#define __NEURUN_MODEL_OPERATION_EQUAL_NODE_H__
-
-#include "model/operation/Node.h"
-
-namespace neurun
-{
-namespace model
-{
-namespace operation
-{
-
-class EqualNode : public model::operation::Node
-{
-public:
-  enum Input
-  {
-    INPUT0 = 0,
-    INPUT1 = 1
-  };
-
-public:
-  EqualNode(const operand::IndexSet &inputs, const operand::IndexSet &outputs);
-
-public:
-  virtual void accept(NodeVisitor &&) const override;
-  virtual std::string getName() const override { return "Equal"; }
-};
-
-} // namespace operation
-} // namespace model
-} // namespace neurun
-
-#endif // __NEURUN_MODEL_OPERATION_EQUAL_NODE_H__
index 3b23507..8c05b20 100644 (file)
@@ -38,8 +38,7 @@
 #include "DivNode.h"
 #include "ExpNode.h"
 #include "ReduceMaxNode.h"
-#include "NotEqualNode.h"
-#include "EqualNode.h"
+#include "ComparisonNode.h"
 #include "LogicalAndNode.h"
 #include "LogicalOrNode.h"
 #include "LogicalNotNode.h"
diff --git a/runtimes/neurun/core/include/model/operation/NotEqualNode.h b/runtimes/neurun/core/include/model/operation/NotEqualNode.h
deleted file mode 100644 (file)
index baf2349..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __NEURUN_MODEL_OPERATION_NOT_EQUAL_NODE_H__
-#define __NEURUN_MODEL_OPERATION_NOT_EQUAL_NODE_H__
-
-#include "model/operation/Node.h"
-
-namespace neurun
-{
-namespace model
-{
-namespace operation
-{
-
-class NotEqualNode : public model::operation::Node
-{
-public:
-  enum Input
-  {
-    INPUT0 = 0,
-    INPUT1 = 1
-  };
-
-public:
-  NotEqualNode(const operand::IndexSet &inputs, const operand::IndexSet &outputs);
-
-public:
-  virtual void accept(NodeVisitor &&) const override;
-  virtual std::string getName() const override { return "NotEqual"; }
-};
-
-} // namespace operation
-} // namespace model
-} // namespace neurun
-
-#endif // __NEURUN_MODEL_OPERATION_NOT_EQUAL_NODE_H__
index dea9cb5..1b8b793 100644 (file)
@@ -42,7 +42,7 @@ OP(DivNode                 , true)
 OP(TransposeNode           , true)
 OP(ExpNode                 , true)
 OP(ReduceMaxNode           , true)
-OP(NotEqualNode            , true)
+OP(ComparisonNode          , true)
 OP(LogicalAndNode          , true)
 OP(LogicalOrNode           , true)
 OP(LogicalNotNode          , true)
@@ -60,7 +60,6 @@ OP(HashtableLookupNode     , true)
 OP(PReLUNode               , true)
 OP(TransposeConvNode       , true)
 OP(SQRTNode                , true)
-OP(EqualNode               , true)
 OP(SquaredDifferenceNode   , true)
 OP(TopKV2Node              , true)
 OP(PermuteNode             , false)
diff --git a/runtimes/neurun/core/src/model/operation/ComparisonNode.cc b/runtimes/neurun/core/src/model/operation/ComparisonNode.cc
new file mode 100644 (file)
index 0000000..4643ccd
--- /dev/null
@@ -0,0 +1,40 @@
+/*
+  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  *      http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+#include "model/operation/ComparisonNode.h"
+
+#include <cassert>
+
+#include "model/operation/NodeVisitor.h"
+
+namespace neurun
+{
+namespace model
+{
+namespace operation
+{
+
+void ComparisonNode::accept(NodeVisitor &&v) const { v.visit(*this); }
+
+ComparisonNode::ComparisonNode(const operand::IndexSet &inputs, const operand::IndexSet &outputs,
+                               const Param &param)
+    : model::operation::Node{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
+{
+}
+
+} // namespace operation
+} // namespace model
+} // namespace neurun
diff --git a/runtimes/neurun/core/src/model/operation/EqualNode.cc b/runtimes/neurun/core/src/model/operation/EqualNode.cc
deleted file mode 100644 (file)
index 9004b03..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "model/operation/EqualNode.h"
-
-#include <cassert>
-
-#include "model/operation/NodeVisitor.h"
-
-namespace neurun
-{
-namespace model
-{
-namespace operation
-{
-
-void EqualNode::accept(NodeVisitor &&v) const { v.visit(*this); }
-
-EqualNode::EqualNode(const operand::IndexSet &inputs, const operand::IndexSet &outputs)
-    : model::operation::Node{OperandConstraint::createExact(2u), inputs, outputs}
-{
-}
-
-} // namespace operation
-} // namespace model
-} // namespace neurun
diff --git a/runtimes/neurun/core/src/model/operation/NotEqualNode.cc b/runtimes/neurun/core/src/model/operation/NotEqualNode.cc
deleted file mode 100644 (file)
index 867184b..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "model/operation/NotEqualNode.h"
-
-#include <cassert>
-
-#include "model/operation/NodeVisitor.h"
-
-namespace neurun
-{
-namespace model
-{
-namespace operation
-{
-
-void NotEqualNode::accept(NodeVisitor &&v) const { v.visit(*this); }
-
-NotEqualNode::NotEqualNode(const operand::IndexSet &inputs, const operand::IndexSet &outputs)
-    : model::operation::Node{OperandConstraint::createExact(2u), inputs, outputs}
-{
-}
-
-} // namespace operation
-} // namespace model
-} // namespace neurun
index 7846fa8..f238927 100644 (file)
@@ -613,7 +613,10 @@ OperationFactory::OperationFactory()
     //  1 -> input2 Tensor Index
     operand::IndexSet inputs{init_param.inputs[0], init_param.inputs[1]};
 
-    return new operation::NotEqualNode{inputs, outputs};
+    operation::ComparisonNode::Param param;
+    param.comparison_type = operation::ComparisonNode::ComparisonType::NotEqual;
+
+    return new operation::ComparisonNode{inputs, outputs, param};
   };
 
   _map[ANEURALNETWORKS_LOGICAL_AND_EX] = [](const OperationFactory::Param &init_param) {
@@ -921,7 +924,10 @@ OperationFactory::OperationFactory()
     //  1 -> input1 Tensor Index
     operand::IndexSet inputs{init_param.inputs[0], init_param.inputs[1]};
 
-    return new operation::EqualNode{inputs, outputs};
+    operation::ComparisonNode::Param param;
+    param.comparison_type = operation::ComparisonNode::ComparisonType::Equal;
+
+    return new operation::ComparisonNode{inputs, outputs, param};
   };
 
   _map[ANEURALNETWORKS_SQUARED_DIFFERENCE_EX] = [](const OperationFactory::Param &init_param) {