[newrt] Implement FullyConnected kernel for CPU (#1943)
author Sujin Kim / Motion Control Lab (SR) / Engineer / Samsung Electronics <sjsujin.kim@samsung.com>
Fri, 13 Jul 2018 07:10:12 +0000 (16:10 +0900)
committer Hyeongseok Oh / Motion Control Lab (SR) / Staff Engineer / Samsung Electronics <hseok82.oh@samsung.com>
Fri, 13 Jul 2018 07:10:12 +0000 (16:10 +0900)
* [newrt] Implement FullyConnected kernel for CPU

This commit implements the FullyConnected kernel for CPU in the new runtime.

- Other jobs
  - Add getNumberOfElements
  - Add zeroPoint (offset) to the operand Shape
  - Add QuantizeMultiplierSmallerThanOne and GetQuantizedConvolutionMultipler

Signed-off-by: sjsujinkim <sjsujin.kim@samsung.com>
* Add TODO; add weight/bias initializers for quant8_asymm

* Add comment for GetQuantizedConvolutionMultipler
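
A rough usage sketch of the new kernel, mirroring the StageGenerator change below
(the buffer and shape variables are illustrative placeholders, not part of this commit):

    // Sketch only: configure and run the CPU FullyConnected kernel the way the
    // generated stage does; buffers and shapes would come from the tensor builder.
    std::unique_ptr<::internal::kernels::cpu::FullyConnectedLayer> fn{
        new ::internal::kernels::cpu::FullyConnectedLayer};

    fn->configure(input_buffer, ifm_shape, weight_buffer, weight_shape,
                  bias_buffer, bias_shape, ANEURALNETWORKS_FUSED_NONE,
                  output_buffer, ofm_shape);
    fn->run(); // dispatches to fullyConnectedFloat32() or fullyConnectedQuant8()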

runtimes/new_runtime/src/internal/Model.h
runtimes/new_runtime/src/internal/cpu/InitializerGenerator.cc
runtimes/new_runtime/src/internal/cpu/StageGenerator.cc
runtimes/new_runtime/src/internal/kernels/cpufallback/FullyConnectedLayer.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/kernels/cpufallback/FullyConnectedLayer.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.cc
runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.h
runtimes/new_runtime/src/internal/nnapi/kernel/Reader.h
runtimes/new_runtime/src/model.cc

diff --git a/runtimes/new_runtime/src/internal/Model.h b/runtimes/new_runtime/src/internal/Model.h
index 629290a..0e27d83 100644 (file)
@@ -54,10 +54,12 @@ public:
   const std::vector<int32_t> &dims() const { return _dims; }
   int32_t type() const { return _type; }
   float scale() const { return _scale; }
-  void set(int32_t type, float scale)
+  int32_t offset() const { return _offset; }
+  void set(int32_t type, float scale, int32_t offset)
   {
     _type = type;
     _scale = scale;
+    _offset = offset;
   }
 
 public:
@@ -67,8 +69,10 @@ public:
 
 private:
   std::vector<int32_t> _dims;
+  // TODO: type, scale, and offset should be moved to another class.
   int32_t _type;
   float _scale;
+  int32_t _offset;
 };
 
 } // namespace operand
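
For context, the offset added above is the operand's NNAPI zeroPoint; a minimal
sketch of the affine mapping it takes part in for QUANT8_ASYMM data (the helper
name is illustrative, not part of this commit):

    // Sketch only: relation between a quantized value and its real value.
    inline float dequantize(uint8_t q, float scale, int32_t offset)
    {
      return scale * (static_cast<int32_t>(q) - offset);
    }
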
diff --git a/runtimes/new_runtime/src/internal/cpu/InitializerGenerator.cc b/runtimes/new_runtime/src/internal/cpu/InitializerGenerator.cc
index a582824..a41ba7c 100644 (file)
@@ -4,6 +4,8 @@
 #include "internal/nnapi/kernel/View.h"
 #include "util/kernel/IndexIterator.h"
 
+#include "NeuralNetworks.h"
+
 namespace internal
 {
 namespace cpu
@@ -38,7 +40,77 @@ InitializerGenerator::generateWeight(const ::internal::tflite::op::Conv2D::impli
 Initializer
 InitializerGenerator::generateWeight(const ::internal::tflite::op::FullyConnected::Node &node)
 {
-  throw std::runtime_error("NYI");
+  const ::internal::tflite::operand::Index weight_index{node.param().weight_index};
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+
+  const auto num_output = _ctx.at(weight_index).shape().dim(0);
+  auto weight_base = _ctx.at(weight_index).data().base();
+  auto weight_size = _ctx.at(weight_index).data().size();
+  auto weight_type = _ctx.at(weight_index).shape().type();
+
+  // NOTE We assume that input is a feature map
+  // TODO Remove this restriction!
+  const auto ifm_shape = _ctx.at(input_index).shape().asFeature();
+
+  switch (weight_type)
+  {
+    case ANEURALNETWORKS_TENSOR_FLOAT32:
+    {
+      return [num_output, ifm_shape, weight_base, weight_size](::arm_compute::ITensor &tensor) {
+        const ::nnfw::util::kernel::Shape ker_shape{num_output, ifm_shape.C, ifm_shape.H,
+                                                    ifm_shape.W};
+        const ::internal::nnapi::kernel::Reader<float> from{ker_shape, weight_base, weight_size};
+
+        ::nnfw::util::kernel::iterate(ker_shape)
+            << [&](uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) {
+                 const auto value = from.at(nth, ch, row, col);
+
+                 uint32_t offset = 0;
+
+                 // NNAPI uses NHWC ordering
+                 offset += nth * ifm_shape.H * ifm_shape.W * ifm_shape.C;
+                 offset += row * ifm_shape.W * ifm_shape.C;
+                 offset += col * ifm_shape.C;
+                 offset += ch;
+
+                 const ::arm_compute::Coordinates coordinate{offset};
+
+                 auto into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
+
+                 *into = value;
+               };
+      };
+    }
+    case ANEURALNETWORKS_TENSOR_QUANT8_ASYMM:
+    {
+      return [num_output, ifm_shape, weight_base, weight_size](::arm_compute::ITensor &tensor) {
+        const ::nnfw::util::kernel::Shape ker_shape{num_output, ifm_shape.C, ifm_shape.H,
+                                                    ifm_shape.W};
+        const ::internal::nnapi::kernel::Reader<uint8_t> from{ker_shape, weight_base, weight_size};
+        ::nnfw::util::kernel::iterate(ker_shape)
+            << [&](uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) {
+                 const auto value = from.at(nth, ch, row, col);
+                 uint32_t offset = 0;
+
+                 // NNAPI uses NHWC ordering
+                 offset += nth * ifm_shape.H * ifm_shape.W * ifm_shape.C;
+                 offset += row * ifm_shape.W * ifm_shape.C;
+                 offset += col * ifm_shape.C;
+                 offset += ch;
+
+                 const ::arm_compute::Coordinates coordinate{offset};
+
+                 auto into = reinterpret_cast<uint8_t *>(tensor.ptr_to_element(coordinate));
+
+                 *into = value;
+               };
+      };
+    }
+    default:
+    {
+      throw std::runtime_error("Not supported weight type");
+    }
+  }
 }
 
 Initializer
@@ -69,7 +141,51 @@ InitializerGenerator::generateBias(const ::internal::tflite::op::Conv2D::implici
 Initializer
 InitializerGenerator::generateBias(const ::internal::tflite::op::FullyConnected::Node &node)
 {
-  throw std::runtime_error("NYI");
+  const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+
+  auto bias_base = _ctx.at(bias_index).data().base();
+  auto bias_type = _ctx.at(bias_index).shape().type();
+  const auto bias_size = _ctx.at(bias_index).shape().asVector();
+
+  switch (bias_type)
+  {
+    case ANEURALNETWORKS_TENSOR_FLOAT32:
+    {
+      return [bias_base, bias_size](::arm_compute::ITensor &tensor) {
+        for (uint32_t n = 0; n < bias_size; ++n)
+        {
+          const ::arm_compute::Coordinates coordinate{n};
+
+          float *into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
+
+          const float *from = reinterpret_cast<const float *>(bias_base) + n;
+          const auto value = *from;
+
+          *into = value;
+        }
+      };
+    }
+    case ANEURALNETWORKS_TENSOR_QUANT8_ASYMM:
+    {
+      return [bias_base, bias_size](::arm_compute::ITensor &tensor) {
+        for (uint32_t n = 0; n < bias_size; ++n)
+        {
+          const ::arm_compute::Coordinates coordinate{n};
+
+          uint8_t *into = reinterpret_cast<uint8_t *>(tensor.ptr_to_element(coordinate));
+
+          const uint8_t *from = reinterpret_cast<const uint8_t *>(bias_base) + n;
+          const auto value = *from;
+
+          *into = value;
+        }
+      };
+    }
+    default:
+    {
+      throw std::runtime_error("Not supported bias type");
+    }
+  }
 }
 
 } // namespace arm_compute
diff --git a/runtimes/new_runtime/src/internal/cpu/StageGenerator.cc b/runtimes/new_runtime/src/internal/cpu/StageGenerator.cc
index 9f5042b..b944398 100644 (file)
@@ -7,6 +7,7 @@
 #include "internal/kernels/cpufallback/AvgPoolLayer.h"
 #include "internal/kernels/cpufallback/MaxPoolLayer.h"
 #include "internal/kernels/cpufallback/ConcatLayer.h"
+#include "internal/kernels/cpufallback/FullyConnectedLayer.h"
 
 #include "logging.h"
 
@@ -361,7 +362,61 @@ Stage StageGenerator::generate(const ::internal::tflite::op::Concat::Node &node)
 
 Stage StageGenerator::generate(const ::internal::tflite::op::FullyConnected::Node &node)
 {
-  throw std::runtime_error("NYI");
+  VERBOSE(FullyConnected) << "generate CPU FullyConnected" << std::endl;
+
+  const ::internal::tflite::operand::Index output_index{node.param().output_index};
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+  const ::internal::tflite::operand::Index weight_index{node.param().weight_index};
+  const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+  const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
+
+  // Construct operation parameters
+  struct Param
+  {
+    int output_index;
+    int input_index;
+    int weight_index;
+    int bias_index;
+
+    ::internal::tflite::operand::Shape ofm_shape{1};
+    ::internal::tflite::operand::Shape ifm_shape{1};
+    ::internal::tflite::operand::Shape weight_shape{1};
+    ::internal::tflite::operand::Shape bias_shape{1};
+
+    FuseCode activation;
+  };
+
+  Param param;
+
+  param.output_index = output_index.asInt();
+  param.input_index = input_index.asInt();
+  param.weight_index = weight_index.asInt();
+  param.bias_index = bias_index.asInt();
+
+  param.ofm_shape = _ctx.at(output_index).shape();
+  param.ifm_shape = _ctx.at(input_index).shape();
+  param.weight_shape = _ctx.at(weight_index).shape();
+  param.bias_shape = _ctx.at(bias_index).shape();
+
+  param.activation = static_cast<FuseCode>(_ctx.at(activation_index).asScalar<int32_t>());
+
+  auto tensors = _tensor_builder;
+
+  return [tensors, param](IExecutionBuilder &builder) {
+    auto output_alloc = tensors->at(::internal::tflite::operand::Index{param.output_index}).get();
+    auto input_alloc = tensors->at(::internal::tflite::operand::Index{param.input_index}).get();
+    auto weight_alloc = tensors->at(::internal::tflite::operand::Index{param.weight_index}).get();
+    auto bias_alloc = tensors->at(::internal::tflite::operand::Index{param.bias_index}).get();
+
+    std::unique_ptr<::internal::kernels::cpu::FullyConnectedLayer> fn{
+        new ::internal::kernels::cpu::FullyConnectedLayer};
+
+    fn->configure(input_alloc->buffer(), param.ifm_shape, weight_alloc->buffer(),
+                  param.weight_shape, bias_alloc->buffer(), param.bias_shape, param.activation,
+                  output_alloc->buffer(), param.ofm_shape);
+
+    builder.append(std::move(fn));
+  };
 }
 
 Stage StageGenerator::generate(const ::internal::tflite::op::Reshape::Node &node)
diff --git a/runtimes/new_runtime/src/internal/kernels/cpufallback/FullyConnectedLayer.cc b/runtimes/new_runtime/src/internal/kernels/cpufallback/FullyConnectedLayer.cc
new file mode 100644 (file)
index 0000000..98cf250
--- /dev/null
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FullyConnectedLayer.h"
+
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "internal/kernels/cpufallback/OperationUtils.h"
+
+#include <mutex>
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+// executionMutex is used to protect concurrent access of non-threadsafe resources
+// like gemmlowp::GemmContext.
+// std::mutex is safe for pthreads on Android.
+static std::mutex executionMutex;
+bool FullyConnectedLayer::fullyConnectedFloat32()
+{
+  float output_activation_min, output_activation_max;
+  CalculateActivationRangeFloat(_activation, &output_activation_min, &output_activation_max);
+  // b/80425683: the optimized implementation produces incorrect results when the
+  // number of input elements is the square of batch_size.
+  uint32_t batch_size = getSizeOfDimension(_outputShape, 0);
+  uint32_t input_n_elements = getNumberOfElements(_inputShape);
+  if (batch_size * batch_size == input_n_elements)
+  {
+    ::tflite::reference_ops::FullyConnected(
+        reinterpret_cast<const float *>(_inputData), convertShapeToDims(_inputShape),
+        reinterpret_cast<const float *>(_weightsData), convertShapeToDims(_weightsShape),
+        reinterpret_cast<const float *>(_biasData), convertShapeToDims(_biasShape),
+        output_activation_min, output_activation_max, reinterpret_cast<float *>(_outputData),
+        convertShapeToDims(_outputShape));
+  }
+  else
+  {
+    ::tflite::optimized_ops::FullyConnected(
+        reinterpret_cast<const float *>(_inputData), convertShapeToDims(_inputShape),
+        reinterpret_cast<const float *>(_weightsData), convertShapeToDims(_weightsShape),
+        reinterpret_cast<const float *>(_biasData), convertShapeToDims(_biasShape),
+        output_activation_min, output_activation_max, reinterpret_cast<float *>(_outputData),
+        convertShapeToDims(_outputShape));
+  }
+  return true;
+}
+
+bool FullyConnectedLayer::fullyConnectedQuant8()
+{
+  int32_t inputOffset = -_inputShape.offset;
+  int32_t weightsOffset = -_weightsShape.offset;
+  int32_t outputOffset = _outputShape.offset;
+  float real_multiplier = 0.0;
+  int32_t output_multiplier = 0;
+  int32_t output_shift = 0;
+  int32_t output_activation_min = 0;
+  int32_t output_activation_max = 0;
+  // Caution: despite the name, 'Convolution' here is just the math term; the same multiplier applies to FullyConnected.
+  if (!GetQuantizedConvolutionMultipler(_inputShape, _weightsShape, _biasShape, _outputShape,
+                                        &real_multiplier) ||
+      !QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, &output_shift))
+  {
+    return false;
+  }
+  CalculateActivationRangeUint8(_activation, _outputShape, &output_activation_min,
+                                &output_activation_max);
+  static gemmlowp::GemmContext gemm_context;
+  // Prevent concurrent executions that access gemm_context.
+  std::unique_lock<std::mutex> lock(executionMutex);
+  // Allow gemmlowp to automatically decide how many threads to use.
+  gemm_context.set_max_num_threads(0);
+  ::tflite::optimized_ops::FullyConnected(
+      _inputData, convertShapeToDims(_inputShape), inputOffset, _weightsData,
+      convertShapeToDims(_weightsShape), weightsOffset,
+      reinterpret_cast<const int32_t *>(_biasData), convertShapeToDims(_biasShape), outputOffset,
+      output_multiplier, output_shift, output_activation_min, output_activation_max, _outputData,
+      convertShapeToDims(_outputShape), &gemm_context);
+  return true;
+}
+
+void FullyConnectedLayer::configure(
+    uint8_t *inputData, const internal::tflite::operand::Shape inputShape, uint8_t *weightsData,
+    const internal::tflite::operand::Shape weightsShape, uint8_t *biasData,
+    const internal::tflite::operand::Shape biasShape, FuseCode activation, uint8_t *outputData,
+    const internal::tflite::operand::Shape outputShape)
+{
+  _inputData = inputData;
+  _inputShape = convertShape(inputShape);
+  _inputType = inputShape.type();
+  _weightsData = weightsData;
+  _weightsShape = convertShape(weightsShape);
+  _biasData = biasData;
+  _biasShape = convertShape(biasShape);
+  _activation = activation;
+  _outputData = outputData;
+  _outputShape = convertShape(outputShape);
+}
+
+void FullyConnectedLayer::run()
+{
+  if (_inputType == static_cast<uint32_t>(OperandType::TENSOR_FLOAT32))
+  {
+    fullyConnectedFloat32();
+  }
+  else if (_inputType == static_cast<uint32_t>(OperandType::TENSOR_QUANT8_ASYMM))
+  {
+    fullyConnectedQuant8();
+  }
+}
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/kernels/cpufallback/FullyConnectedLayer.h b/runtimes/new_runtime/src/internal/kernels/cpufallback/FullyConnectedLayer.h
new file mode 100644 (file)
index 0000000..50b40af
--- /dev/null
@@ -0,0 +1,58 @@
+#ifndef __INTERNAL_KERNELS_CPU_FULLYCONNECTEDLAYER_H__
+#define __INTERNAL_KERNELS_CPU_FULLYCONNECTEDLAYER_H__
+
+#include <NeuralNetworks.h>
+
+#include <arm_compute/runtime/IFunction.h>
+
+#include "internal/Model.h"
+#include "internal/kernels/cpufallback/OperationUtils.h"
+
+using namespace internal::kernels::cpu;
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+class FullyConnectedLayer : public ::arm_compute::IFunction
+{
+public:
+  FullyConnectedLayer() {}
+
+public:
+  bool fullyConnectedFloat32();
+
+  bool fullyConnectedQuant8();
+
+  void configure(uint8_t *inputData, const internal::tflite::operand::Shape inputShape,
+                 uint8_t *weightsData, const internal::tflite::operand::Shape weightsShape,
+                 uint8_t *biasData, const internal::tflite::operand::Shape biasShape,
+                 FuseCode activation, uint8_t *outputData,
+                 const internal::tflite::operand::Shape outputShape);
+
+  void run();
+
+private:
+  uint8_t *_inputData;
+  uint8_t *_weightsData;
+  uint8_t *_biasData;
+  uint8_t *_outputData;
+
+  Shape _inputShape;
+  Shape _weightsShape;
+  Shape _biasShape;
+  Shape _outputShape;
+
+  FuseCode _activation;
+
+  int32_t _inputType;
+};
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
+
+#endif // __INTERNAL_KERNELS_CPU_FULLYCONNECTEDLAYER_H__
diff --git a/runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.cc b/runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.cc
index 685b386..7ca51a9 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <cmath>
 #include <algorithm>
+#include <cassert>
 
 namespace internal
 {
@@ -12,6 +13,16 @@ namespace cpu
 
 uint32_t getNumberOfDimensions(const Shape &shape) { return shape.dimensions.size(); }
 
+uint32_t getNumberOfElements(const Shape &shape)
+{
+  uint32_t count = 1;
+  for (size_t i = 0; i < shape.dimensions.size(); i++)
+  {
+    count *= shape.dimensions[i];
+  }
+  return count;
+}
+
 uint32_t getSizeOfDimension(const Shape &shape, uint32_t dimensionIdx)
 {
   if (dimensionIdx >= shape.dimensions.size())
@@ -22,6 +33,49 @@ uint32_t getSizeOfDimension(const Shape &shape, uint32_t dimensionIdx)
   return shape.dimensions[dimensionIdx];
 }
 
+bool QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier,
+                                      int32_t *right_shift)
+{
+  assert(double_multiplier >= 0.);
+  assert(double_multiplier < 1.);
+  if (double_multiplier == 0.)
+  {
+    *quantized_multiplier = 0;
+    *right_shift = 0;
+    return true;
+  }
+  assert(double_multiplier > 0.);
+  const double q = std::frexp(double_multiplier, right_shift);
+  *right_shift *= -1;
+  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
+  assert(q_fixed <= (1ll << 31));
+  if (q_fixed == (1ll << 31))
+  {
+    q_fixed /= 2;
+    --*right_shift;
+  }
+  assert(*right_shift >= 0);
+  assert(q_fixed <= std::numeric_limits<int32_t>::max());
+  *quantized_multiplier = static_cast<int32_t>(q_fixed);
+  return true;
+}
+
+bool GetQuantizedConvolutionMultipler(const Shape &inputShape, const Shape &filterShape,
+                                      const Shape &biasShape, const Shape &outputShape,
+                                      float *multiplier)
+{
+  const float input_product_scale = inputShape.scale * filterShape.scale;
+  const float bias_scale = biasShape.scale;
+  const float output_scale = outputShape.scale;
+  // The following conditions must be guaranteed by the training pipeline.
+  assert(std::abs(input_product_scale - bias_scale) <=
+         1e-6 * std::min(input_product_scale, bias_scale));
+  assert(input_product_scale >= 0);
+  assert(input_product_scale < output_scale);
+  *multiplier = input_product_scale / output_scale;
+  return true;
+}
+
 void CalculateActivationRangeFloat(int32_t activation, float *activation_min, float *activation_max)
 {
   if (activation == ANEURALNETWORKS_FUSED_RELU)
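
A quick sanity check of QuantizeMultiplierSmallerThanOne above (illustrative
numbers, not taken from this commit):

    // Sketch only: decompose a real multiplier < 1 into a Q31 fixed-point
    // multiplier and a right shift, as fullyConnectedQuant8() does.
    int32_t quantized_multiplier = 0;
    int32_t right_shift = 0;
    if (QuantizeMultiplierSmallerThanOne(0.25, &quantized_multiplier, &right_shift))
    {
      // quantized_multiplier == 1073741824 (0.5 in Q31) and right_shift == 1,
      // so 0.25 ~= quantized_multiplier * 2^-31 * 2^-right_shift.
    }
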
diff --git a/runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.h b/runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.h
index b2ab60b..d72d348 100644 (file)
@@ -39,6 +39,8 @@ struct Shape
 
 uint32_t getNumberOfDimensions(const Shape &shape);
 
+uint32_t getNumberOfElements(const Shape &shape);
+
 uint32_t getSizeOfDimension(const Shape &shape, uint32_t dimensionIdx);
 
 inline ::tflite::Dims<4> convertShapeToDims(const Shape &shape)
@@ -66,6 +68,13 @@ inline ::tflite::Dims<4> convertShapeToDims(const Shape &shape)
   return dims;
 }
 
+__wur bool QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier,
+                                            int32_t *right_shift);
+
+__wur bool GetQuantizedConvolutionMultipler(const Shape &inputShape, const Shape &filterShape,
+                                            const Shape &biasShape, const Shape &outputShape,
+                                            float *multiplier);
+
 void CalculateActivationRangeFloat(int32_t activation, float *activation_min,
                                    float *activation_max);
 
diff --git a/runtimes/new_runtime/src/internal/nnapi/kernel/Reader.h b/runtimes/new_runtime/src/internal/nnapi/kernel/Reader.h
index 7ae3bed..a0a0c29 100644 (file)
@@ -11,9 +11,7 @@ namespace nnapi
 namespace kernel
 {
 
-template <typename T> class Reader;
-
-template <> class Reader<float> final : public nnfw::util::kernel::Reader<float>
+template <typename T> class Reader final : public nnfw::util::kernel::Reader<T>
 {
 public:
   Reader(const ::nnfw::util::kernel::Shape &shape, const uint8_t *base, size_t size)
@@ -26,7 +24,7 @@ public:
   const nnfw::util::kernel::Shape &shape(void) const { return _shape; }
 
 public:
-  float at(uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) const override
+  T at(uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) const override
   {
     // NNAPI uses NHWC ordering
     uint32_t index = 0;
@@ -36,7 +34,7 @@ public:
     index += col * _shape.C;
     index += ch;
 
-    const float *ptr = reinterpret_cast<const float *>(_base);
+    const T *ptr = reinterpret_cast<const T *>(_base);
 
     return ptr[index];
   }
diff --git a/runtimes/new_runtime/src/model.cc b/runtimes/new_runtime/src/model.cc
index 8fe005e..891584e 100644 (file)
@@ -30,7 +30,7 @@ int ANeuralNetworksModel_addOperand(ANeuralNetworksModel *model,
     shape.dim(axis) = type->dimensions[axis];
   }
 
-  shape.set(type->type, type->scale);
+  shape.set(type->type, type->scale, type->zeroPoint);
 
   model->deref().operands().append(shape);