#include "internal/layers/HashtableLookupLayer.h"
#include <arm_compute/runtime/CL/CLScheduler.h>
+#include <map>
+#include <cstring>
void HashtableLookupLayer::configure(::arm_compute::ITensor *lookups, ::arm_compute::ITensor *keys,
::arm_compute::ITensor *values, ::arm_compute::ITensor *output,
_values = values;
_output = output;
_hits = hits;
+ _lookup_indices.resize(lookups->info()->dimension(0), -1);
}
void HashtableLookupLayer::run()
const int32_t *lookups_buf = reinterpret_cast<int32_t *>(_lookups->buffer());
const int32_t *keys_buf = reinterpret_cast<int32_t *>(_keys->buffer());
- const auto values_buf = _values->buffer();
- auto output_buf = _output->buffer();
uint8_t *hits_buf = reinterpret_cast<uint8_t *>(_hits->buffer());
const auto lookups_info = _lookups->info();
const auto keys_info = _keys->info();
const auto output_info = _output->info();
- const size_t num_rows = values_info->dimension(1);
- const size_t row_bytes = values_info->total_size() / num_rows;
+ // NOTE The first dimension's position must be always at the end of dimensions.
+ const auto first_dim_pos = values_info->num_dimensions() - 1;
+ const size_t first_dim = values_info->dimension(first_dim_pos);
- int number_of_keys = keys_info->dimension(0);
+ std::map<int32_t, size_t> key_map;
+ const int keys_num = keys_info->dimension(0);
+ for (size_t key_index = 0; key_index < keys_num; key_index++)
+ {
+ key_map[keys_buf[key_index]] = key_index;
+ }
- for (size_t i = 0; i < lookups_info->dimension(0); ++i)
+ const int lookups_num = lookups_info->dimension(0);
+ for (size_t i = 0; i < lookups_num; ++i)
{
- int idx = -1;
- auto lookup_value = reinterpret_cast<const int32_t *>(lookups_buf) + i;
- for (int key_index = 0; key_index < number_of_keys; key_index++)
+ const auto lookup_value = lookups_buf[i];
+ const auto it = key_map.find(lookup_value);
+ if (it != key_map.end())
{
- auto current_key = reinterpret_cast<const int32_t *>(keys_buf) + key_index;
- if (*lookup_value == *current_key)
- {
- idx = key_index;
- break;
- }
+ if (it->second >= first_dim)
+ throw std::runtime_error("HashTable Lookup: index out of bounds.");
+ _lookup_indices[i] = it->second;
}
+ }
- if (idx >= num_rows || idx < 0) // Miss
- {
- size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, idx});
- size_t row_offset_by_i = output_info->offset_element_in_bytes({0, i});
-
- unsigned char *sink_addr = output_buf + row_offset_by_i;
- memset(sink_addr, 0, row_bytes);
+ // If each strides of values and output are different, applied padding size of the two tensors are
+ // different, therefore, it can not be copied at once.
+ auto can_copy_at_once = [&]() -> bool {
+ const auto &values_strides = values_info->strides_in_bytes();
+ const auto &output_strides = output_info->strides_in_bytes();
- hits_buf[i] = 0;
- }
- else // Hit
+ for (size_t i = 0; i < first_dim_pos; ++i)
{
- size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, idx});
- size_t row_offset_by_i = output_info->offset_element_in_bytes({0, i});
+ if (values_strides[i] != values_strides[i])
+ return false;
+ }
- unsigned char *sink_addr = output_buf + row_offset_by_i;
- unsigned char *source_addr = values_buf + row_offset_by_idx;
- memcpy(sink_addr, source_addr, row_bytes);
+ return true;
+ };
- hits_buf[i] = 1;
- }
+ using ::arm_compute::Window;
+ using ::arm_compute::Iterator;
+ using ::arm_compute::Coordinates;
+
+ size_t copy_bytes;
+ Window window;
+ if (can_copy_at_once())
+ {
+ copy_bytes = values_info->total_size() / first_dim;
+ window.use_tensor_dimensions(output_info->tensor_shape(), first_dim_pos);
}
+ else
+ {
+ copy_bytes = values_info->dimension(0) * values_info->element_size();
+ window.use_tensor_dimensions(output_info->tensor_shape(), Window::DimY);
+ }
+
+ Iterator it(_output, window);
+ execute_window_loop(window,
+ [&](const Coordinates &id) {
+ Coordinates values_id = id;
+ const int idx = id[first_dim_pos];
+ const int lookup_index = _lookup_indices[idx];
+ if (lookup_index >= 0)
+ {
+ values_id.set(first_dim_pos, lookup_index);
+ memcpy(it.ptr(), _values->ptr_to_element(values_id), copy_bytes);
+ hits_buf[lookup_index] = 1;
+ }
+ else
+ {
+ memset(it.ptr(), 0, copy_bytes);
+ hits_buf[lookup_index] = 0;
+ }
+ },
+ it);
if (::internal::arm_compute::isGpuMode())
{