Enable Exp op for ACL neon (#7093)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Mon, 2 Sep 2019 06:14:04 +0000 (15:14 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 2 Sep 2019 06:14:04 +0000 (15:14 +0900)
This commit enables to support Exp op for ACL neon.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/neurun/backend/acl_neon/KernelGenerator.cc
runtimes/neurun/backend/acl_neon/ShapeFixer.cc
runtimes/neurun/backend/acl_neon/ShapeFixer.h
runtimes/neurun/core/src/compiler/OperationValidator.cc
runtimes/neurun/core/src/compiler/OperationValidator.h
tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon

index a3d627b..be93fed 100644 (file)
@@ -1040,8 +1040,23 @@ void KernelGenerator::visit(const model::operation::DivNode &node)
 
 void KernelGenerator::visit(const model::operation::ExpNode &node)
 {
-  (void)node;
-  throw std::runtime_error("Not supported, yet");
+  const auto output_index{node.getOutputs().at(0)};
+  const auto input_index{node.getInputs().at(model::operation::ExpNode::Input::INPUT)};
+
+  auto output_alloc = _tensor_builder->at(output_index).get();
+  auto input_alloc = _tensor_builder->at(input_index).get();
+
+  std::unique_ptr<::arm_compute::IFunction> fn;
+
+  auto l = nnfw::cpp14::make_unique<::arm_compute::NEExpLayer>();
+
+  l->configure(input_alloc->handle(), output_alloc->handle());
+
+  fn = std::move(l);
+
+  auto acl_fn = asAclFunction(std::move(fn));
+
+  _execution_builder->append(std::move(acl_fn));
 }
 
 void KernelGenerator::visit(const model::operation::ReduceMaxNode &node)
index e6fdca7..96a3520 100644 (file)
@@ -79,6 +79,8 @@ void ShapeFixer::visit(const model::operation::ConcatNode &node)
     _tensor_builder->dimCorrection(inputs, false);
 }
 
+void ShapeFixer::visit(const model::operation::ExpNode &) { /* DO NOTHING */}
+
 void ShapeFixer::visit(const model::operation::FullyConnectedNode &node)
 {
   using model::operation::FullyConnectedNode;
index 392df91..336179c 100644 (file)
@@ -44,6 +44,7 @@ public:
   void visit(const model::operation::MeanNode &) override;
   void visit(const model::operation::AvgPool2DNode &) override;
   void visit(const model::operation::ConcatNode &) override;
+  void visit(const model::operation::ExpNode &) override;
   void visit(const model::operation::FullyConnectedNode &) override;
   void visit(const model::operation::MulNode &) override;
   void visit(const model::operation::ReLUNode &) override;
index ee5ebfa..2677157 100644 (file)
@@ -338,6 +338,18 @@ void OperationValidator::visit(const model::operation::EmbeddingLookupNode &node
   }
 }
 
+void OperationValidator::visit(const model::operation::ExpNode &node)
+{
+  const auto output_index{node.getOutputs().at(0)};
+  const auto input_index{node.getInputs().at(model::operation::ExpNode::Input::INPUT)};
+
+  UNUSED_RELEASE(output_index);
+  UNUSED_RELEASE(input_index);
+
+  assert(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
+  assert(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
+}
+
 void OperationValidator::visit(const model::operation::HashtableLookupNode &node)
 {
   const auto output_index{
index 99b9a02..96dc1b8 100644 (file)
@@ -50,6 +50,7 @@ public:
   void visit(const model::operation::RNNNode &node) override;
   void visit(const model::operation::SpaceToDepthNode &node) override;
   void visit(const model::operation::EmbeddingLookupNode &node) override;
+  void visit(const model::operation::ExpNode &node) override;
   void visit(const model::operation::HashtableLookupNode &node) override;
   void visit(const model::operation::TransposeConvNode &node) override;
   void visit(const model::operation::GatherNode &node) override;
index e87aba3..143108e 100644 (file)
@@ -9,7 +9,6 @@ GeneratedTests.dequantize
 GeneratedTests.embedding_lookup
 GeneratedTests.embedding_lookup_2d_nnfw
 GeneratedTests.embedding_lookup_4d_nnfw
-GeneratedTests.exp_ex*
 GeneratedTests.floor_
 GeneratedTests.greater_equal_ex*
 GeneratedTests.hashtable_lookup*
@@ -50,3 +49,6 @@ generatedtests.logical_not_ex*
 # Need to be fixed
 GeneratedTests.logical_not_ex_1D
 GeneratedTests.logical_not_ex_4D
+# Float error
+GeneratedTests.exp_ex_1D_float
+GeneratedTests.exp_ex_2D_float