Remove duplicated codes from SimpleEmbeddingLookup (#3159)
author장지섭/동작제어Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Tue, 16 Oct 2018 01:08:57 +0000 (10:08 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 16 Oct 2018 01:08:57 +0000 (10:08 +0900)
This commit removes duplicated codes from SimpleEmbeddingLookup.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/pure_arm_compute/src/internal/layers/SimpleEmbeddingLookup.cc

index 30aa858..5f6069c 100644 (file)
@@ -21,7 +21,8 @@ void SimpleEmbeddingLookup::configure(::arm_compute::ITensor *lookups,
                                       ::arm_compute::ITensor *values,
                                       ::arm_compute::ITensor *output)
 {
-  // Assume that verification of operands are already done at Planner::visit()
+  assert(values->info()->num_dimensions() == output->info()->num_dimensions());
+  assert(values->info()->num_dimensions() > 1 && values->info()->num_dimensions() <= 4);
   _lookups = lookups;
   _values = values;
   _output = output;
@@ -47,76 +48,30 @@ void SimpleEmbeddingLookup::run()
   const auto values_info = _values->info();
   const auto output_info = _output->info();
 
-  // TODO Refactor below duplicated code!
-  const auto values_rank = values_info->num_dimensions();
-  switch (values_rank)
+  // NOTE The first dimension's position is always at the end of dimensions.
+  const auto first_dim_pos = values_info->num_dimensions() - 1;
+  ::arm_compute::Coordinates offset_coord{};
+  for (size_t i = 0; i < first_dim_pos; ++i)
   {
-    case 2:
-      // (H,W) in nnapi -> (W,H) in acl
-      {
-        const size_t row_size = values_info->dimension(1);
-        const size_t row_bytes = values_info->total_size() / row_size;
-        for (size_t i = 0; i < lookups_info->dimension(0); ++i)
-        {
-          if (lookups_buf[i] < 0 || lookups_buf[i] >= row_size)
-            throw std::runtime_error("Embedding Lookup: index out of bounds.");
-
-          size_t idx = lookups_buf[i];
-          size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, idx});
-          size_t row_offset_by_i = output_info->offset_element_in_bytes({0, i});
-
-          unsigned char *sink_addr = output_buf + row_offset_by_i;
-          unsigned char *source_addr = values_buf + row_offset_by_idx;
-          memcpy(sink_addr, source_addr, row_bytes);
-        }
-      }
-      break;
-    case 3:
-      // (B,H,W) in nnapi -> (W,H,B) in acl
-      {
-        const size_t row_size = values_info->dimension(2);
-        const size_t row_bytes = values_info->total_size() / row_size;
-        for (size_t i = 0; i < lookups_info->dimension(0); ++i)
-        {
-          if (lookups_buf[i] < 0 || lookups_buf[i] >= row_size)
-            throw std::runtime_error("Embedding Lookup: index out of bounds.");
-
-          size_t idx = lookups_buf[i];
-          size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, 0, idx});
-          size_t row_offset_by_i = output_info->offset_element_in_bytes({0, 0, i});
+    offset_coord.set(i, 0);
+  }
 
-          unsigned char *sink_addr = output_buf + row_offset_by_i;
-          unsigned char *source_addr = values_buf + row_offset_by_idx;
-          memcpy(sink_addr, source_addr, row_bytes);
-        }
-      }
-      break;
-    case 4:
-      // (N,H,W,C) in nnapi -> (N,C,H,W) in acl
-      {
-        const size_t row_size = values_info->dimension(3);
-        const size_t row_bytes = values_info->total_size() / row_size;
-        for (size_t i = 0; i < lookups_info->dimension(0); ++i)
-        {
-          if (lookups_buf[i] < 0 || lookups_buf[i] >= row_size)
-            throw std::runtime_error("Embedding Lookup: index out of bounds.");
+  const size_t first_dim = values_info->dimension(first_dim_pos);
+  const size_t copy_bytes = values_info->total_size() / first_dim;
+  for (size_t i = 0; i < lookups_info->dimension(0); ++i)
+  {
+    if (lookups_buf[i] < 0 || lookups_buf[i] >= first_dim)
+      throw std::runtime_error("Embedding Lookup: index out of bounds.");
 
-          size_t idx = lookups_buf[i];
-          size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, 0, 0, idx});
-          size_t row_offset_by_i = output_info->offset_element_in_bytes({0, 0, 0, i});
+    size_t idx = lookups_buf[i];
+    offset_coord.set(first_dim_pos, idx);
+    size_t values_offset = values_info->offset_element_in_bytes(offset_coord);
+    offset_coord.set(first_dim_pos, i);
+    size_t output_offset = output_info->offset_element_in_bytes(offset_coord);
 
-          unsigned char *sink_addr = output_buf + row_offset_by_i;
-          unsigned char *source_addr = values_buf + row_offset_by_idx;
-          memcpy(sink_addr, source_addr, row_bytes);
-        }
-      }
-      break;
-    case 1:
-      // In this case, shape of values actually is matrix but the height(row size) is 1 in acl. If
-      // row size is 1, this op is not needed and it means this situtation could be wrong.
-      throw std::runtime_error("Wrong usage of EmbeddingLookup op!");
-    default:
-      throw std::runtime_error("Not supported rank!");
+    unsigned char *sink_addr = output_buf + output_offset;
+    unsigned char *source_addr = values_buf + values_offset;
+    memcpy(sink_addr, source_addr, copy_bytes);
   }
 
   if (::internal::arm_compute::isGpuMode())