Make to support HashTableLookup op for acl neon (#7694)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Tue, 24 Sep 2019 05:21:01 +0000 (14:21 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 24 Sep 2019 05:21:01 +0000 (14:21 +0900)
This commit makes to support HashTableLookup op for acl neon.
  - Introduce NEHashTableLookup and NEHashTableLookupKernel
  - Apply NEHashTableLookup layer for neurun

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/libs/ARMComputeEx/arm_compute/core/NEON/kernels/NEHashtableLookupKernel.h [new file with mode: 0644]
runtimes/libs/ARMComputeEx/arm_compute/runtime/NEON/NEFunctionsEx.h
runtimes/libs/ARMComputeEx/arm_compute/runtime/NEON/functions/NEHashtableLookup.h [new file with mode: 0644]
runtimes/libs/ARMComputeEx/src/core/NEON/kernels/NEHashtableLookupKernel.cpp [new file with mode: 0644]
runtimes/libs/ARMComputeEx/src/runtime/NEON/functions/NEHashtableLookup.cpp [new file with mode: 0644]
runtimes/neurun/backend/acl_neon/KernelGenerator.cc
runtimes/neurun/backend/acl_neon/KernelGenerator.h
runtimes/neurun/backend/acl_neon/ShapeFixer.cc
runtimes/neurun/backend/acl_neon/ShapeFixer.h
tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon
tests/scripts/neurun_frameworktest_list.armv7l.acl_neon.txt

diff --git a/runtimes/libs/ARMComputeEx/arm_compute/core/NEON/kernels/NEHashtableLookupKernel.h b/runtimes/libs/ARMComputeEx/arm_compute/core/NEON/kernels/NEHashtableLookupKernel.h
new file mode 100644 (file)
index 0000000..d8976e7
--- /dev/null
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_NEHASHTABLELOOKUPKERNEL_H__
+#define __ARM_COMPUTE_NEHASHTABLELOOKUPKERNEL_H__
+
+#include "arm_compute/core/NEON/INEKernel.h"
+#include "arm_compute/core/Types.h"
+
+namespace arm_compute
+{
+class ITensor;
+
+/** NEON kernel to perform HashtableLookup operation */
+class NEHashtableLookupKernel : public INEKernel
+{
+public:
+  const char *name() const override { return "NEHashtableLookupKernel"; }
+  /** Default constructor */
+  NEHashtableLookupKernel();
+  /** Prevent instances of this class from being copied (As this class contains pointers). */
+  NEHashtableLookupKernel(const NEHashtableLookupKernel &) = delete;
+  /** Prevent instances of this class from being copied (As this class contains pointers). */
+  NEHashtableLookupKernel &operator=(const NEHashtableLookupKernel &) = delete;
+  /** Allow instances of this class to be moved */
+  NEHashtableLookupKernel(NEHashtableLookupKernel &&) = default;
+  /** Allow instances of this class to be moved */
+  NEHashtableLookupKernel &operator=(NEHashtableLookupKernel &&) = default;
+  /** Initialize the kernel's inputs, outputs.
+   *
+   * @param[in]  lookups  Lookups 1D tensor that values are indices into the first dimension of
+   * input. Data types supported: S32
+   * @param[in]  keys     Keys 1D tensor. keys and input pair represent a map.
+   *                      Data types supported: S32
+   * @param[in]  input    Source tensor.
+   *                      Data types supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
+   * @param[out] output   Destination tensor. Data types and data layouts supported: Same as @p
+   * input.
+   * @param[out] hits     Hits 1D tensor. A boolean tensor that indicates whether the lookup hits
+   * (True) or not (False). Data types supported: U8/QASYMM8
+   * input.
+   */
+  void configure(const ITensor *lookups, const ITensor *keys, const ITensor *input, ITensor *output,
+                 ITensor *hits);
+  /** Static function to check if given info will lead to a valid configuration of @ref
+   * NEHashtableLookupKernel
+   *
+   * @param[in]  lookups  The lookups tensor info. Data types supported: S32.
+   * @param[in]  keys     The keys tensor info. keys and input pair represent a map.
+   *                      Data types supported: S32
+   * @param[in]  input    The input tensor info.
+   *                      Data types supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
+   * @param[out] output   The output tensor info. Data types and data layouts supported: Same as @p
+   * input.
+   * @param[out] hits     The hits tensor info. A boolean tensor that indicates whether the lookup
+   * hits (True) or not (False). Data types supported: U8/QASYMM8
+   *
+   * @return a status
+   */
+  static Status validate(const ITensorInfo *lookups, const ITensorInfo *keys,
+                         const ITensorInfo *input, const ITensorInfo *output,
+                         const ITensorInfo *hits);
+
+  // Inherited methods overridden:
+  void run(const Window &window, const ThreadInfo &info) override;
+
+private:
+  const ITensor *_lookups; /** Lookups tensor */
+  const ITensor *_keys;    /** Keys tensor */
+  const ITensor *_input;   /** Source tensor */
+  ITensor *_output;        /** Destination tensor */
+  ITensor *_hits;          /** Hits tensor */
+};
+} // namespace arm_compute
+#endif /*__ARM_COMPUTE_NEHASHTABLELOOKUPKERNEL_H__ */
index c6e80ba..0a3cb11 100644 (file)
@@ -22,6 +22,7 @@
 #include <arm_compute/runtime/NEON/functions/NEElementwiseUnaryLayerEx.h>
 #include <arm_compute/runtime/NEON/functions/NEEmbeddingLookup.h>
 #include <arm_compute/runtime/NEON/functions/NEFullyConnectedReshapingLayer.h>
+#include <arm_compute/runtime/NEON/functions/NEHashtableLookup.h>
 #include <arm_compute/runtime/NEON/functions/NEPReLU.h>
 #include <arm_compute/runtime/NEON/functions/NEReduceMeanEx.h>
 #include <arm_compute/runtime/NEON/functions/NEReduceSum.h>
diff --git a/runtimes/libs/ARMComputeEx/arm_compute/runtime/NEON/functions/NEHashtableLookup.h b/runtimes/libs/ARMComputeEx/arm_compute/runtime/NEON/functions/NEHashtableLookup.h
new file mode 100644 (file)
index 0000000..69abf01
--- /dev/null
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2016-2018 ARM Limited.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file NEHashtableLookup.h
+ * @ingroup COM_AI_RUNTIME
+ * @brief This file contains arm_compute::NEHashtableLookup class
+ */
+
+#ifndef __ARM_COMPUTE_NEHASHTABLELOOKUP_H__
+#define __ARM_COMPUTE_NEHASHTABLELOOKUP_H__
+
+#include "arm_compute/runtime/NEON/INESimpleFunctionNoBorder.h"
+
+#include <vector>
+
+namespace arm_compute
+{
+class ITensor;
+
+/**
+ * @brief Class to perform HashtableLookup operation
+ */
+class NEHashtableLookup : public INESimpleFunctionNoBorder
+{
+public:
+  /**
+   * @brief Set the input and output tensors.
+   * @param[in]  lookups  Lookups 1D tensor that values are indices into the first dimension of
+   *                      input. Data types supported: S32
+   * @param[in]  keys     Keys 1D tensor. keys and input pair represent a map.
+   *                      Data types supported: S32
+   * @param[in]  input    Source tensor.
+   *                      Data types supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
+   * @param[out] output   Destination tensor. Data types and data layouts supported: Same as @p
+   *                      input.
+   * @param[out] hits     Hits 1D tensor. A boolean tensor that indicates whether the lookup hits
+   *                      (True) or not (False). Data types supported: U8/QASYMM8
+   * @return N/A
+   */
+  void configure(const ITensor *lookups, const ITensor *keys, const ITensor *input, ITensor *output,
+                 ITensor *hits);
+  /** Static function to check if given info will lead to a valid configuration of @ref NECopy
+   *
+   * @param[in]  lookups  Lookups 1D tensor info.
+   *                      Data types supported: S32
+   * @param[in]  keys     Keys 1D tensor info. keys and input pair represent a map.
+   *                      Data types supported: S32
+   * @param[in]  input    Source tensor info.
+   *                      Data types supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
+   * @param[in]  output   Destination tensor info. Data types and data layouts supported: Same as @p
+   * input.
+   * @param[in]  hits     Hits 1D tensor info. A boolean tensor that indicates whether the lookup
+   * hits (True) or not (False). Data types supported: U8/QASYMM8
+   *
+   * @return a status
+   */
+  static Status validate(const ITensorInfo *lookups, const ITensorInfo *keys,
+                         const ITensorInfo *input, const ITensorInfo *output,
+                         const ITensorInfo *hits);
+};
+}
+#endif /*__ARM_COMPUTE_NEHASHTABLELOOKUP_H__ */
diff --git a/runtimes/libs/ARMComputeEx/src/core/NEON/kernels/NEHashtableLookupKernel.cpp b/runtimes/libs/ARMComputeEx/src/core/NEON/kernels/NEHashtableLookupKernel.cpp
new file mode 100644 (file)
index 0000000..391337b
--- /dev/null
@@ -0,0 +1,181 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2018-2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEHashtableLookupKernel.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
+
+#include <unordered_map>
+
+using namespace arm_compute;
+
+namespace
+{
+constexpr size_t NOT_HIT = 0xFFFFFFFF;
+} // namespace
+
+NEHashtableLookupKernel::NEHashtableLookupKernel()
+    : _lookups(nullptr), _keys(nullptr), _input(nullptr), _output(nullptr), _hits{nullptr}
+{
+}
+
+void NEHashtableLookupKernel::configure(const ITensor *lookups, const ITensor *keys,
+                                        const ITensor *input, ITensor *output, ITensor *hits)
+{
+  ARM_COMPUTE_ERROR_ON_NULLPTR(lookups, keys, input, output, hits);
+  ARM_COMPUTE_ERROR_THROW_ON(
+      validate(lookups->info(), keys->info(), input->info(), output->info(), hits->info()));
+
+  _lookups = lookups;
+  _keys = keys;
+  _input = input;
+  _output = output;
+  _hits = hits;
+
+  // Auto initialize output if not initialized
+  auto out_shape{input->info()->tensor_shape()};
+  out_shape.set(out_shape.num_dimensions() - 1, lookups->info()->num_dimensions(), false);
+  auto_init_if_empty(*output->info(), out_shape, 1, input->info()->data_type(),
+                     input->info()->quantization_info());
+
+  // Auto initialize hits if not initialized
+  auto_init_if_empty(*hits->info(), lookups->info()->tensor_shape(), 1, DataType::U8);
+
+  INEKernel::configure(calculate_max_window(*output->info()));
+}
+
+Status NEHashtableLookupKernel::validate(const ITensorInfo *lookups, const ITensorInfo *keys,
+                                         const ITensorInfo *input, const ITensorInfo *output,
+                                         const ITensorInfo *hits)
+{
+  ARM_COMPUTE_ERROR_ON_NULLPTR(lookups, keys, input, output, hits);
+  ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(
+      input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16,
+      DataType::U32, DataType::S32, DataType::F16, DataType::F32);
+  ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lookups, 1, DataType::S32);
+  ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(keys, 1, DataType::S32);
+
+  ARM_COMPUTE_ERROR_ON(input->num_dimensions() < 2 && input->num_dimensions() > 4);
+  ARM_COMPUTE_ERROR_ON(lookups->num_dimensions() > 1);
+  ARM_COMPUTE_ERROR_ON(keys->num_dimensions() > 1);
+  ARM_COMPUTE_ERROR_ON(keys->dimension(0) != input->dimension(input->num_dimensions() - 1));
+
+  // Validate in case of configured output
+  if (output->total_size() > 0)
+  {
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    ARM_COMPUTE_ERROR_ON(input->num_dimensions() != output->num_dimensions());
+    ARM_COMPUTE_ERROR_ON(output->dimension(output->num_dimensions() - 1) != lookups->dimension(0));
+    for (size_t i = 0; i < output->num_dimensions() - 1; ++i)
+    {
+      ARM_COMPUTE_ERROR_ON(input->dimension(i) != output->dimension(i));
+    }
+  }
+
+  // Validate in case of configured hits
+  if (hits->total_size() > 0)
+  {
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(hits, 1, DataType::U8, DataType::QASYMM8);
+    ARM_COMPUTE_ERROR_ON(hits->dimension(0) != output->dimension(output->num_dimensions() - 1));
+    ARM_COMPUTE_ERROR_ON(hits->dimension(0) != lookups->dimension(0));
+    ARM_COMPUTE_ERROR_ON(hits->num_dimensions() > 1);
+  }
+
+  return Status{};
+}
+
+void NEHashtableLookupKernel::run(const Window &window, const ThreadInfo &info)
+{
+  ARM_COMPUTE_UNUSED(info);
+  ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+  ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+
+  const size_t lookup_dim = _output->info()->num_dimensions() - 1;
+  const int const_0 = _output->info()->data_type() == DataType::QASYMM8
+                          ? _output->info()->quantization_info().offset
+                          : 0;
+
+  std::unordered_map<int32_t, size_t> key_index_map;
+  for (size_t n = 0; n < _keys->info()->dimension(0); ++n)
+  {
+    const int32_t key = *reinterpret_cast<int32_t *>(_keys->ptr_to_element({n}));
+    key_index_map[key] = n;
+  }
+  std::vector<size_t> lookup_indices;
+  for (size_t k = 0; k < _lookups->info()->dimension(0); ++k)
+  {
+    const int32_t key = *reinterpret_cast<int32_t *>(_lookups->ptr_to_element({k}));
+    const auto it = key_index_map.find(key);
+    if (it == key_index_map.end())
+    {
+      lookup_indices.emplace_back(NOT_HIT);
+      *_hits->ptr_to_element({k}) = 0;
+    }
+    else
+    {
+#if defined(ARM_COMPUTE_DEBUG_ENABLED)
+      if (it->second >= _keys->info()->dimension(0))
+        ARM_COMPUTE_ERROR("HashTable Lookup: Index out of bounds.");
+#endif // defined(ARM_COMPUTE_DEBUG_ENABLED)
+      lookup_indices.emplace_back(it->second);
+      *_hits->ptr_to_element({k}) = 1;
+    }
+  }
+
+  Window output_window{window};
+  output_window.set(Window::DimX,
+                    Window::Dimension(output_window.x().start(), output_window.x().end(),
+                                      _input->info()->dimension(0)));
+
+  Window out_slice = output_window.first_slice_window_4D();
+  do
+  {
+    Iterator output_it(_output, out_slice);
+
+    execute_window_loop(out_slice,
+                        [&](const Coordinates &id) {
+                          const auto lookup = lookup_indices.at(id[lookup_dim]);
+                          if (lookup == NOT_HIT)
+                          {
+                            memset(output_it.ptr(), const_0,
+                                   _output->info()->dimension(0) * _output->info()->element_size());
+                          }
+                          else
+                          {
+                            Coordinates input_id{id};
+                            input_id.set(lookup_dim, lookup);
+                            memcpy(output_it.ptr(), _input->ptr_to_element(input_id),
+                                   _output->info()->dimension(0) * _output->info()->element_size());
+                          }
+
+                        },
+                        output_it);
+
+  } while (window.slide_window_slice_4D(out_slice));
+}
diff --git a/runtimes/libs/ARMComputeEx/src/runtime/NEON/functions/NEHashtableLookup.cpp b/runtimes/libs/ARMComputeEx/src/runtime/NEON/functions/NEHashtableLookup.cpp
new file mode 100644 (file)
index 0000000..624185d
--- /dev/null
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2016-2018 ARM Limited.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "arm_compute/runtime/NEON/functions/NEHashtableLookup.h"
+
+#include "arm_compute/core/NEON/kernels/NEHashtableLookupKernel.h"
+#include "support/ToolchainSupport.h"
+
+using namespace arm_compute;
+
+void NEHashtableLookup::configure(const ITensor *lookups, const ITensor *keys, const ITensor *input,
+                                  ITensor *output, ITensor *hits)
+{
+  auto k = arm_compute::support::cpp14::make_unique<NEHashtableLookupKernel>();
+  k->configure(lookups, keys, input, output, hits);
+  _kernel = std::move(k);
+}
+
+Status NEHashtableLookup::validate(const ITensorInfo *lookups, const ITensorInfo *keys,
+                                   const ITensorInfo *input, const ITensorInfo *output,
+                                   const ITensorInfo *hits)
+{
+  return NEHashtableLookupKernel::validate(lookups, keys, input, output, hits);
+}
index cce6efc..ae76f33 100644 (file)
@@ -671,6 +671,35 @@ void KernelGenerator::visit(const model::operation::FullyConnectedNode &node)
   ActivationBuilder{*_execution_builder}.append(activation, output_alloc->handle());
 }
 
+void KernelGenerator::visit(const model::operation::HashtableLookupNode &node)
+{
+  const auto output_index{
+      node.getOutputs().at(model::operation::HashtableLookupNode::Output::OUTPUT)};
+  const auto hits_index{node.getOutputs().at(model::operation::HashtableLookupNode::Output::HITS)};
+
+  const auto lookups_index{
+      node.getInputs().at(model::operation::HashtableLookupNode::Input::LOOKUPS)};
+  const auto keys_index{node.getInputs().at(model::operation::HashtableLookupNode::Input::KEYS)};
+  const auto values_index{
+      node.getInputs().at(model::operation::HashtableLookupNode::Input::VALUES)};
+
+  auto output_alloc = _tensor_builder->at(output_index).get();
+  auto hits_alloc = _tensor_builder->at(hits_index).get();
+
+  auto lookups_alloc = _tensor_builder->at(lookups_index).get();
+  auto keys_alloc = _tensor_builder->at(keys_index).get();
+  auto values_alloc = _tensor_builder->at(values_index).get();
+
+  auto fn = nnfw::cpp14::make_unique<::arm_compute::NEHashtableLookup>();
+
+  fn->configure(lookups_alloc->handle(), keys_alloc->handle(), values_alloc->handle(),
+                output_alloc->handle(), hits_alloc->handle());
+
+  auto acl_fn = asAclFunction(std::move(fn));
+
+  _execution_builder->append(std::move(acl_fn));
+}
+
 void KernelGenerator::visit(const model::operation::L2NormalizationNode &node)
 {
   const auto ofm_index{node.getOutputs().at(0)};
index fe8f312..3fd90cc 100644 (file)
@@ -49,6 +49,7 @@ public:
   void visit(const model::operation::EmbeddingLookupNode &) override;
   void visit(const model::operation::FloorNode &) override;
   void visit(const model::operation::FullyConnectedNode &) override;
+  void visit(const model::operation::HashtableLookupNode &) override;
   void visit(const model::operation::L2NormalizationNode &) override;
   void visit(const model::operation::L2Pool2DNode &) override;
   void visit(const model::operation::LocalResponseNormalizationNode &) override;
index 9e052d3..7da326b 100644 (file)
@@ -111,6 +111,15 @@ void ShapeFixer::visit(const model::operation::FullyConnectedNode &node)
     _tensor_builder->dimCorrection(input_index, false);
 }
 
+void ShapeFixer::visit(const model::operation::HashtableLookupNode &node)
+{
+  const auto output_index{node.getOutputs().at(0)};
+  const auto values_index{
+      node.getInputs().at(model::operation::EmbeddingLookupNode::Input::VALUES)};
+  _tensor_builder->dimCorrection(values_index, false);
+  _tensor_builder->dimCorrection(output_index, false);
+}
+
 void ShapeFixer::visit(const model::operation::L2NormalizationNode &) { /* DO NOTHING */}
 
 void ShapeFixer::visit(const model::operation::L2Pool2DNode &) { /* DO NOTHING */}
index e8dada6..b86b638 100644 (file)
@@ -51,6 +51,7 @@ public:
   void visit(const model::operation::ExpNode &) override;
   void visit(const model::operation::FloorNode &) override;
   void visit(const model::operation::FullyConnectedNode &) override;
+  void visit(const model::operation::HashtableLookupNode &) override;
   void visit(const model::operation::L2NormalizationNode &) override;
   void visit(const model::operation::L2Pool2DNode &) override;
   void visit(const model::operation::LocalResponseNormalizationNode &) override;
index 2a09308..3bc8889 100644 (file)
@@ -4,7 +4,6 @@
 # Not support operations
 TrivialTest.BroadcastMulTwo
 GeneratedTests.dequantize
-GeneratedTests.hashtable_lookup*
 GeneratedTests.lsh_projection*
 GeneratedTests.mobilenet*
 GeneratedTests.space_to_depth*