Apply CL Kernel of Embedding_Lookup to PACL (#3616)
author장지섭/동작제어Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Mon, 19 Nov 2018 07:44:12 +0000 (16:44 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 19 Nov 2018 07:44:12 +0000 (16:44 +0900)
This commit applies CL Kernel of Embedding_Lookup to PACL.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc

index fdca888..3dc9e72 100644 (file)
@@ -59,6 +59,7 @@
 #include <arm_compute/runtime/CL/functions/CLNormalizationLayerEx.h>
 #include <arm_compute/runtime/CL/functions/CLExp.h>
 #include <arm_compute/runtime/CL/functions/CLBatchToSpaceND.h>
+#include <arm_compute/runtime/CL/functions/CLEmbeddingLookup.h>
 #include <arm_compute/runtime/CL/functions/CLEqual.h>
 #include <arm_compute/runtime/CL/functions/CLSquaredDifference.h>
 #include <arm_compute/runtime/CL/functions/CLNeg.h>
@@ -4635,11 +4636,27 @@ void Planner::visit(const ::internal::tflite::op::EmbeddingLookup::Node &node)
     auto lookups_alloc = ctx.at(::internal::tflite::operand::Index{param.lookups_index});
     auto values_alloc = ctx.at(::internal::tflite::operand::Index{param.values_index});
 
-    auto fn = nnfw::make_unique<SimpleEmbeddingLookup>();
+    if (from_env<bool>(std::getenv("USE_SIMPLE_EMBEDDINGLOOKUP")))
+    {
+      auto fn = nnfw::make_unique<SimpleEmbeddingLookup>();
+
+      fn->configure(lookups_alloc, values_alloc, output_alloc);
+
+      builder.append("EmbeddingLookup", std::move(fn));
+    }
+    else if (::internal::arm_compute::isGpuMode())
+    {
+      auto fn = nnfw::make_unique<::arm_compute::CLEmbeddingLookup>();
 
-    fn->configure(lookups_alloc, values_alloc, output_alloc);
+      fn->configure(CAST_CL(values_alloc), CAST_CL(output_alloc), CAST_CL(lookups_alloc));
 
-    builder.append("EmbeddingLookup", std::move(fn));
+      builder.append("EmbeddingLookup", std::move(fn));
+    }
+    else
+    {
+      // TODO Enable NEON Support
+      throw std::runtime_error("Not supported, yet");
+    }
   };
 
   _builder.addStage(stage);