[nnc/interpreter] Quantized kernels for supporting InceptionV4 (#8669)
authorPavel Iliutchenko/AI Tools Lab /SRR/Engineer/Samsung Electronics <p.iliutchenk@samsung.com>
Thu, 14 Nov 2019 16:35:24 +0000 (19:35 +0300)
committerAlexander Efimov/AI Tools Lab /SRR/Engineer/Samsung Electronics <a.efimov@samsung.com>
Thu, 14 Nov 2019 16:35:24 +0000 (19:35 +0300)
* Supported quantized kernels for operations: Concat, FC, MaxPool,
Softmax

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/nnc/backends/interpreter/Interpreter.cpp
compiler/nnc/backends/interpreter/ops/Concat.cpp [new file with mode: 0644]
compiler/nnc/backends/interpreter/ops/Concat.h
compiler/nnc/backends/interpreter/ops/FullyConnected.cpp
compiler/nnc/backends/interpreter/ops/FullyConnected.h
compiler/nnc/backends/interpreter/ops/MaxPool2D.cpp
compiler/nnc/backends/interpreter/ops/MaxPool2D.h
compiler/nnc/backends/interpreter/ops/QuantizationHelpers.cpp [deleted file]
compiler/nnc/backends/interpreter/ops/QuantizationHelpers.h
compiler/nnc/backends/interpreter/ops/Softmax.cpp
compiler/nnc/backends/interpreter/ops/Softmax.h

index fcc5c97..fd139a2 100644 (file)
@@ -25,6 +25,7 @@
 #include "ops/DepthwiseConv2D.h"
 #include "ops/Div.h"
 #include "ops/ELU.h"
+#include "ops/Fill.h"
 #include "ops/FullyConnected.h"
 #include "ops/Gather.h"
 #include "ops/LeakyReLU.h"
@@ -141,8 +142,18 @@ void NNInterpreter::visit(ops::ConstantOp &op) { setOutputTensors(op, {op.getVal
 void NNInterpreter::visit(ops::ConcatOp &op)
 {
   auto inputs = getInputTensors(op);
-  auto outputs = Concat<float>(inputs, op.getOutputShape(0), op.getAxis())();
-  setOutputTensors(op, std::move(outputs));
+  auto outputs = getOutputTensors(op);
+  switch (inputs[0].get().getElementType())
+  {
+    case mir::DataType::FLOAT32:
+      Concatenation(inputs, op.getAxis(), outputs[0]);
+      break;
+    case mir::DataType::UINT8:
+      ConcatenationWithScaling(inputs, op.getAxis(), outputs[0]);
+      break;
+    default:
+      throw std::runtime_error("NYI");
+  }
 }
 
 void NNInterpreter::visit(ops::Conv2DOp &op)
@@ -167,7 +178,18 @@ void NNInterpreter::visit(ops::Conv2DOp &op)
 void NNInterpreter::visit(ops::MaxPool2DOp &op)
 {
   auto inputs = getInputTensors(op);
-  auto outputs = MaxPool2D(inputs[0], op)();
+  std::vector<mir::TensorVariant> outputs;
+  switch (inputs[0].get().getElementType())
+  {
+    case mir::DataType::FLOAT32:
+      outputs = MaxPool2D(inputs[0], op)();
+      break;
+    case mir::DataType::UINT8:
+      outputs = QuantizedMaxPool2D(op, inputs[0]);
+      break;
+    default:
+      throw std::runtime_error("NYI");
+  }
   setOutputTensors(op, std::move(outputs));
 }
 
@@ -194,15 +216,38 @@ void NNInterpreter::visit(ops::SigmoidOp &op)
 
 void NNInterpreter::visit(ops::SoftmaxOp &op)
 {
-  auto args = getInputTensors(op);
-  auto results = getOutputTensors(op);
-  Softmax(args[0], op.getAxis(), results[0]);
+  auto inputs = getInputTensors(op);
+  assert(inputs.size() == 1);
+  auto outputs = getOutputTensors(op);
+  switch (inputs[0].get().getElementType())
+  {
+    case mir::DataType::FLOAT32:
+      Softmax(inputs[0], op.getAxis(), outputs[0]);
+      break;
+    case mir::DataType::UINT8:
+      QuantizedSoftmax(inputs[0], op.getAxis(), outputs[0]);
+      break;
+    default:
+      throw std::runtime_error("NYI");
+  }
 }
 
 void NNInterpreter::visit(ops::FullyConnectedOp &op)
 {
   auto inputs = getInputTensors(op);
-  auto outputs = FullyConnected(inputs[0], inputs[1], op)();
+  std::vector<mir::TensorVariant> outputs;
+  switch (inputs[0].get().getElementType())
+  {
+    case mir::DataType::FLOAT32:
+      outputs = FullyConnected(inputs[0], inputs[1], op)();
+      break;
+    case mir::DataType::UINT8:
+      assert(inputs.size() == 3);
+      outputs = QuantizedFC(op, inputs[0], inputs[1], inputs[2]);
+      break;
+    default:
+      throw std::runtime_error("NYI");
+  }
   setOutputTensors(op, std::move(outputs));
 }
 
diff --git a/compiler/nnc/backends/interpreter/ops/Concat.cpp b/compiler/nnc/backends/interpreter/ops/Concat.cpp
new file mode 100644 (file)
index 0000000..05c83a1
--- /dev/null
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Concat.h"
+
+#include <cmath>
+#include <cstring>
+
+namespace nnc
+{
+
+void Concatenation(const std::vector<std::reference_wrapper<const mir::TensorVariant>> &inputs,
+                   int axis, mir::TensorVariant &output)
+{
+  const auto &output_shape = output.getShape();
+  const size_t inputs_count = inputs.size();
+  const int32_t concat_dims = output_shape.rank();
+  int64_t concat_size = 0;
+  for (size_t i = 0; i < inputs_count; i++)
+  {
+    const auto &input_shape = inputs[i].get().getShape();
+    assert(input_shape.rank() == concat_dims);
+    for (int32_t j = 0; j < concat_dims; j++)
+    {
+      if (j != axis)
+      {
+        assert(input_shape.dim(j) == output_shape.dim(j));
+      }
+    }
+    concat_size += input_shape.dim(axis);
+  }
+  assert(concat_size == output_shape.dim(axis));
+  // Outer size before axis
+  int32_t outer_size = 1;
+  for (int32_t i = 0; i < axis; i++)
+    outer_size *= output_shape.dim(i);
+  // Inner size after axis
+  int32_t base_inner_size = 1;
+  for (int32_t i = axis + 1; i < concat_dims; i++)
+    base_inner_size *= output_shape.dim(i);
+  // flatten = outer_size * dim(axis) * base_inner_size;
+  std::vector<int32_t> copy_sizes;
+  std::vector<char *> input_ptrs;
+  for (size_t i = 0; i < inputs_count; i++)
+  {
+    const auto input_shape = inputs[i].get().getShape();
+    copy_sizes.push_back(input_shape.dim(axis) * base_inner_size);
+    input_ptrs.push_back(inputs[i].get().atOffset(0));
+  }
+
+  char *output_ptr = output.atOffset(0);
+  const size_t elem_size = inputs[0].get().getElementSize();
+  for (int32_t i = 0; i < outer_size; i++)
+  {
+    for (size_t j = 0; j < inputs_count; j++)
+    {
+      std::memcpy(output_ptr, input_ptrs[j], copy_sizes[j] * elem_size);
+      output_ptr += copy_sizes[j] * elem_size;
+      input_ptrs[j] += copy_sizes[j] * elem_size;
+    }
+  }
+}
+
+void ConcatenationWithScaling(
+    const std::vector<std::reference_wrapper<const mir::TensorVariant>> &inputs, int axis,
+    mir::TensorVariant &output)
+{
+  const size_t inputs_count = inputs.size();
+  std::vector<int32_t> input_zeropoints(inputs_count);
+  std::vector<float> input_scales(inputs_count);
+  const auto &output_shape = output.getShape();
+  const int32_t concat_dimensions = output_shape.rank();
+  int64_t concat_size = 0;
+  for (size_t i = 0; i < inputs_count; i++)
+  {
+    const auto &input_type = inputs[i].get().getType();
+    assert(input_type.isQuantized());
+    assert(input_type.getElementType() == mir::DataType::UINT8);
+    const auto &input_shape = input_type.getShape();
+    assert(input_shape.rank() == concat_dimensions);
+
+    for (int32_t j = 0; j < concat_dimensions; j++)
+      if (j != axis)
+        assert(input_shape.dim(j) == output_shape.dim(j));
+
+    concat_size += input_shape.dim(axis);
+    input_zeropoints[i] = input_type.getQuantization().getZeroPoint();
+    input_scales[i] = input_type.getQuantization().getScale();
+  }
+  assert(concat_size == output_shape.dim(axis));
+
+  const auto &output_type = output.getType();
+  assert(output_type.isQuantized());
+  int32_t output_zeropoint = output_type.getQuantization().getZeroPoint();
+  float output_scale = output_type.getQuantization().getScale();
+
+  // Outer size before axis
+  int32_t outer_size = 1;
+  for (int32_t i = 0; i < axis; i++)
+    outer_size *= output_shape.dim(i);
+  // Inner size after axis
+  int32_t base_inner_size = 1;
+  for (int32_t i = axis + 1; i < concat_dimensions; i++)
+    base_inner_size *= output_shape.dim(i);
+  // flatten = outer_size * dim(axis) * base_inner_size;
+
+  uint8_t *output_ptr = reinterpret_cast<uint8_t *>(output.atOffset(0));
+
+  const float inverse_output_scale = 1.f / output_scale;
+  for (int k = 0; k < outer_size; k++)
+  {
+    for (size_t i = 0; i < inputs_count; ++i)
+    {
+      const mir::TensorVariant &input = inputs[i];
+      const int copy_size = input.getShape().dim(axis) * base_inner_size;
+      const char *input_data = input.atOffset(0) + k * copy_size;
+      const uint8_t *input_ptr = reinterpret_cast<const uint8_t *>(input_data);
+      if (input_zeropoints[i] == output_zeropoint && input_scales[i] == output_scale)
+      {
+        std::memcpy(output_ptr, input_ptr, copy_size);
+      }
+      else
+      {
+        const float scale = input_scales[i] * inverse_output_scale;
+        const float bias = -input_zeropoints[i] * scale;
+        for (int j = 0; j < copy_size; ++j)
+        {
+          const int32_t value =
+              static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) + output_zeropoint;
+          output_ptr[j] = static_cast<uint8_t>(std::max(std::min(255, value), 0));
+        }
+      }
+      output_ptr += copy_size;
+    }
+  }
+}
+
+} // namespace nnc
index 31632ec..8ab4139 100644 (file)
  * limitations under the License.
  */
 
-#ifndef _NNC_CORE_BACKEND_INTERPRETER_FILL_IMPL_
-#define _NNC_CORE_BACKEND_INTERPRETER_FILL_IMPL_
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_CONCAT_IMPL_
+#define _NNC_CORE_BACKEND_INTERPRETER_CONCAT_IMPL_
 
-#include "Fill.h"
+#include <mir/TensorVariant.h>
 
 namespace nnc
 {
 
-template <typename T> class Concat : public Fill<T>
-{
-public:
-  Concat(const std::vector<std::reference_wrapper<const mir::TensorVariant>> &inputs,
-         const mir::Shape &outputShape, int32_t axis)
-      : Fill<T>(outputShape, getSingleFunction(inputs, axis))
-  {
-  }
-
-private:
-  const std::function<T(const mir::Index &)>
-  getSingleFunction(const std::vector<std::reference_wrapper<const mir::TensorVariant>> &inputs,
-                    int32_t axis)
-  {
-    std::vector<mir::Tensor<T>> inputAccessors;
-    inputAccessors.reserve(inputs.size());
-    for (auto &in : inputs)
-      inputAccessors.emplace_back(in);
-
-    return std::function<T(const mir::Index &)>([inputAccessors, axis](const mir::Index &id) -> T {
-      unsigned int mi = 0;
-      int32_t along_axis = id.at(axis);
-
-      while (along_axis >= inputAccessors.at(mi).getShape().dim(axis))
-      {
-        along_axis -= inputAccessors[mi].getShape().dim(axis);
-        mi++;
-      }
-
-      mir::Index local_id = id;
-      local_id.at(axis) = along_axis;
+void Concatenation(const std::vector<std::reference_wrapper<const mir::TensorVariant>> &inputs,
+                   int axis, mir::TensorVariant &output);
 
-      return inputAccessors[mi].at(local_id);
-    });
-  }
-};
+void ConcatenationWithScaling(
+    const std::vector<std::reference_wrapper<const mir::TensorVariant>> &inputs, int axis,
+    mir::TensorVariant &output);
 
 } // namespace nnc
 
-#endif //_NNC_CORE_BACKEND_INTERPRETER_FILL_IMPL_
+#endif //_NNC_CORE_BACKEND_INTERPRETER_CONCAT_IMPL_
index 02a4791..b309c70 100644 (file)
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +17,8 @@
 
 #include "FullyConnected.h"
 
+#include "QuantizationHelpers.h"
+
 namespace nnc
 {
 
@@ -106,4 +109,80 @@ std::vector<mir::TensorVariant> FullyConnected::operator()()
   return {res};
 }
 
+std::vector<mir::TensorVariant> QuantizedFC(const mir::ops::FullyConnectedOp &op,
+                                            const mir::TensorVariant &input,
+                                            const mir::TensorVariant &weights,
+                                            const mir::TensorVariant &bias)
+{
+  const auto &input_type = input.getType();
+  const auto &weights_type = weights.getType();
+  const auto &bias_type = bias.getType();
+  const auto &output_type = op.getOutput(0)->getType();
+
+  assert(input_type.isQuantized());
+  assert(weights_type.isQuantized());
+  assert(bias_type.isQuantized());
+  assert(output_type.isQuantized());
+  assert(input_type.getElementType() == mir::DataType::UINT8);
+  assert(weights_type.getElementType() == mir::DataType::UINT8);
+  assert(bias_type.getElementType() == mir::DataType::INT32);
+
+  int32_t input_offset = -input_type.getQuantization().getZeroPoint();
+  int32_t weights_offset = -weights_type.getQuantization().getZeroPoint();
+  int32_t output_offset = output_type.getQuantization().getZeroPoint();
+
+  double input_scale = input_type.getQuantization().getScale();
+  double weights_scale = weights_type.getQuantization().getScale();
+  double output_scale = output_type.getQuantization().getScale();
+
+  double real_multiplier = input_scale * weights_scale / output_scale;
+  int32_t output_multiplier = 0;
+  int output_shift = 0;
+  QuantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
+
+  const mir::Shape &in_shape = input.getShape();
+  const mir::Shape &weights_shape = weights.getShape();
+  const mir::Shape &out_shape = op.getOutputShape(0);
+
+  const int32_t batches = in_shape.dim(0);
+  assert(in_shape.rank() == 2);
+  assert(weights_shape.rank() == 2);
+  assert(in_shape.dim(1) == weights_shape.dim(0));
+  const int32_t accum_depth = weights_shape.dim(0);
+  const int32_t output_depth = weights_shape.dim(1);
+
+  uint8_t *input_data = reinterpret_cast<uint8_t *>(input.atOffset(0));
+  uint8_t *weights_data = reinterpret_cast<uint8_t *>(weights.atOffset(0));
+  int32_t *bias_data = reinterpret_cast<int32_t *>(bias.atOffset(0));
+
+  mir::TensorType res_type(mir::DataType::UINT8, out_shape, output_type.getQuantization());
+  mir::TensorVariant res(res_type);
+  uint8_t *output_data = reinterpret_cast<uint8_t *>(res.atOffset(0));
+
+  int32_t output_min = std::numeric_limits<uint8_t>::min();
+  int32_t output_max = std::numeric_limits<uint8_t>::max();
+
+  for (int32_t b = 0; b < batches; ++b)
+  {
+    for (int32_t out_c = 0; out_c < output_depth; ++out_c)
+    {
+      int32_t acc = 0;
+      for (int d = 0; d < accum_depth; ++d)
+      {
+        int32_t input_val = input_data[b * accum_depth + d];
+        int32_t weights_val = weights_data[d * output_depth + out_c];
+        acc += (weights_val + weights_offset) * (input_val + input_offset);
+      }
+      acc += bias_data[out_c];
+      acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
+      acc += output_offset;
+      acc = std::max(acc, output_min);
+      acc = std::min(acc, output_max);
+      output_data[out_c + output_depth * b] = static_cast<uint8_t>(acc);
+    }
+  }
+
+  return {res};
+}
+
 } // namespace nnc
index 27b3973..016a1c3 100644 (file)
@@ -39,6 +39,11 @@ private:
   const mir::TensorVariant &_weights;
 };
 
+std::vector<mir::TensorVariant> QuantizedFC(const mir::ops::FullyConnectedOp &op,
+                                            const mir::TensorVariant &input,
+                                            const mir::TensorVariant &weights,
+                                            const mir::TensorVariant &bias);
+
 } // namespace nnc
 
 #endif //_NNC_CORE_BACKEND_INTERPRETER_FULLYCONNECTED_
index 5ad6c5e..e484521 100644 (file)
@@ -74,4 +74,62 @@ std::vector<TensorVariant> MaxPool2D::operator()()
   return {res};
 }
 
+std::vector<mir::TensorVariant> QuantizedMaxPool2D(const mir::ops::MaxPool2DOp &op,
+                                                   const mir::TensorVariant &input)
+{
+  const auto &input_type = input.getType();
+  const auto &output_type = op.getOutput(0)->getType();
+
+  assert(input_type.isQuantized());
+  assert(output_type.isQuantized());
+  assert(input_type.getElementType() == DataType::UINT8);
+
+  const auto &input_shape = op.getInputShape(0);
+  const auto &output_shape = op.getOutputShape(0);
+  const auto &window_size = op.getWindowSize();
+  const auto &strides = op.getStrides();
+  const auto &padding_before = op.getPaddingBefore();
+  const auto &padding_after = op.getPaddingAfter();
+
+  constexpr int num_spatial_dims = 2;
+  assert(input.getShape().rank() == 4);
+  assert(window_size.size() == num_spatial_dims);
+  assert(strides.size() == num_spatial_dims);
+  assert(padding_before.size() == num_spatial_dims);
+  assert(padding_after.size() == num_spatial_dims);
+
+  Tensor<uint8_t> input_accessor(input);
+
+  TensorType res_type(mir::DataType::UINT8, output_shape, output_type.getQuantization());
+  TensorVariant res(res_type);
+  Tensor<uint8_t> res_accessor(res);
+
+  ShapeRange in_range(input_shape);
+  Index in_index(input_shape.rank());
+
+  for (const auto &out_index : ShapeRange(output_shape))
+  {
+    // Assuming NHWC format.
+    in_index.at(0) = out_index.at(0);
+    in_index.at(3) = out_index.at(3);
+
+    uint8_t result = 0;
+    for (const auto &window_index : ShapeRange(Shape(window_size)))
+    {
+      // Assuming NHWC format.
+      for (int i = 0; i < num_spatial_dims; ++i)
+        in_index.at(1 + i) =
+            out_index.at(1 + i) * strides[i] + window_index.at(i) - padding_before[i];
+
+      if (in_range.contains(in_index))
+      {
+        result = std::max(result, input_accessor.at(in_index));
+      }
+    }
+    res_accessor.at(out_index) = result;
+  }
+
+  return {res};
+}
+
 } // namespace nnc
index 9950af1..da7ba5f 100644 (file)
@@ -39,6 +39,9 @@ private:
   const mir::Tensor<float> _input;
 };
 
+std::vector<mir::TensorVariant> QuantizedMaxPool2D(const mir::ops::MaxPool2DOp &op,
+                                                   const mir::TensorVariant &input);
+
 } // namespace nnc
 
 #endif //_NNC_CORE_BACKEND_INTERPRETER_MAX_POOL_2D_
diff --git a/compiler/nnc/backends/interpreter/ops/QuantizationHelpers.cpp b/compiler/nnc/backends/interpreter/ops/QuantizationHelpers.cpp
deleted file mode 100644 (file)
index 093affe..0000000
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "QuantizationHelpers.h"
-
-#include <cmath>
-#include <limits>
-
-namespace nnc
-{
-
-void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
-{
-  if (double_multiplier == 0.)
-  {
-    *quantized_multiplier = 0;
-    *shift = 0;
-    return;
-  }
-
-  const double q = std::frexp(double_multiplier, shift);
-  auto q_fixed = static_cast<int64_t>(round(q * (1ll << 31)));
-
-  assert(q_fixed <= (1ll << 31));
-  if (q_fixed == (1ll << 31))
-  {
-    q_fixed /= 2;
-    ++*shift;
-  }
-  assert(q_fixed <= std::numeric_limits<int32_t>::max());
-  // A shift amount smaller than -31 would cause all bits to be shifted out
-  // and thus all results would be zero. We implement that instead with
-  // q_fixed==0, so as to avoid hitting issues with right-shift
-  // operations with shift amounts greater than 31. Note that this happens
-  // roughly when abs(double_multiplier) < 2^-31 and the present handling means
-  // that we're effectively flushing tiny double_multiplier's to zero.
-  // We could conceivably handle values in the range (roughly) [32, 63]
-  // as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view
-  // the present handling is just doing 'flush denormals to zero'. We could
-  // reconsider and actually generate nonzero denormals if a need arises.
-  if (*shift < -31)
-  {
-    *shift = 0;
-    q_fixed = 0;
-  }
-  *quantized_multiplier = static_cast<int32_t>(q_fixed);
-}
-
-void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, int32_t *quantized_multiplier,
-                                         int *left_shift)
-{
-  assert(double_multiplier < 1.0);
-  assert(double_multiplier > 0.0);
-  int shift;
-  QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift);
-  assert(shift <= 0);
-  *left_shift = shift;
-}
-
-int32_t MaskIfNonZero(int32_t a)
-{
-  static const int32_t zero = 0;
-  return a ? ~zero : zero;
-}
-
-int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); }
-
-int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero(a < b); }
-
-int32_t MaskIfGreaterThan(int32_t a, int32_t b) { return MaskIfNonZero(a > b); }
-
-inline int32_t RoundingDivideByPOT(int32_t x, int exponent)
-{
-  assert(exponent >= 0);
-  assert(exponent <= 31);
-  const int32_t mask = (1ll << exponent) - 1;
-  const int32_t remainder = x & mask;
-  const int32_t threshold = (mask >> 1) + (MaskIfLessThan(x, 0) & 1);
-  return (x >> exponent) + (MaskIfGreaterThan(remainder, threshold) & 1);
-}
-
-inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b)
-{
-  bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
-  std::int64_t a_64(a);
-  std::int64_t b_64(b);
-  std::int64_t ab_64 = a_64 * b_64;
-  std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
-  std::int32_t ab_x2_high32 = static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
-  return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
-}
-
-int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
-{
-  int left_shift = shift > 0 ? shift : 0;
-  int right_shift = shift > 0 ? 0 : -shift;
-  return RoundingDivideByPOT(
-      SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), right_shift);
-}
-
-int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier,
-                                                       int left_shift)
-{
-  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(x, quantized_multiplier),
-                             -left_shift);
-}
-
-} // namespace nnc
index 0dd9aea..149e46f 100644 (file)
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 #ifndef _NNC_CORE_BACKEND_INTERPRETER_QUANTIZATION_HELPERS_
 #define _NNC_CORE_BACKEND_INTERPRETER_QUANTIZATION_HELPERS_
 
-#include "mir/TensorType.h"
+#include <cmath>
+#include <limits>
 
 namespace nnc
 {
 
-void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift);
+inline void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
+{
+  if (double_multiplier == 0.)
+  {
+    *quantized_multiplier = 0;
+    *shift = 0;
+    return;
+  }
+
+  const double q = std::frexp(double_multiplier, shift);
+  auto q_fixed = static_cast<int64_t>(round(q * (1ll << 31)));
+
+  assert(q_fixed <= (1ll << 31));
+  if (q_fixed == (1ll << 31))
+  {
+    q_fixed /= 2;
+    ++*shift;
+  }
+  assert(q_fixed <= std::numeric_limits<int32_t>::max());
+  // A shift amount smaller than -31 would cause all bits to be shifted out
+  // and thus all results would be zero. We implement that instead with
+  // q_fixed==0, so as to avoid hitting issues with right-shift
+  // operations with shift amounts greater than 31. Note that this happens
+  // roughly when abs(double_multiplier) < 2^-31 and the present handling means
+  // that we're effectively flushing tiny double_multiplier's to zero.
+  // We could conceivably handle values in the range (roughly) [32, 63]
+  // as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view
+  // the present handling is just doing 'flush denormals to zero'. We could
+  // reconsider and actually generate nonzero denormals if a need arises.
+  if (*shift < -31)
+  {
+    *shift = 0;
+    q_fixed = 0;
+  }
+  *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+inline void QuantizeMultiplierSmallerThanOneExp(double double_multiplier,
+                                                int32_t *quantized_multiplier, int *left_shift)
+{
+  assert(double_multiplier < 1.0);
+  assert(double_multiplier > 0.0);
+  int shift;
+  QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift);
+  assert(shift <= 0);
+  *left_shift = shift;
+}
+
+inline int32_t MaskIfNonZero(int32_t a)
+{
+  static const int32_t zero = 0;
+  return a ? ~zero : zero;
+}
+
+inline int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); }
 
-int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift);
+inline int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero(a < b); }
 
-void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, int32_t *quantized_multiplier,
-                                         int *left_shift);
+inline int32_t MaskIfGreaterThan(int32_t a, int32_t b) { return MaskIfNonZero(a > b); }
 
-int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier,
-                                                       int left_shift);
+inline int32_t RoundingDivideByPOT(int32_t x, int exponent)
+{
+  assert(exponent >= 0);
+  assert(exponent <= 31);
+  const int32_t mask = (1ll << exponent) - 1;
+  const int32_t remainder = x & mask;
+  const int32_t threshold = (mask >> 1) + (MaskIfLessThan(x, 0) & 1);
+  return (x >> exponent) + (MaskIfGreaterThan(remainder, threshold) & 1);
+}
+
+inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b)
+{
+  bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
+  std::int64_t a_64(a);
+  std::int64_t b_64(b);
+  std::int64_t ab_64 = a_64 * b_64;
+  std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
+  std::int32_t ab_x2_high32 = static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
+  return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
+}
+
+inline int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
+{
+  int left_shift = shift > 0 ? shift : 0;
+  int right_shift = shift > 0 ? 0 : -shift;
+  return RoundingDivideByPOT(
+      SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), right_shift);
+}
+
+inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x,
+                                                              int32_t quantized_multiplier,
+                                                              int left_shift)
+{
+  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(x, quantized_multiplier),
+                             -left_shift);
+}
 
 } // namespace nnc
 
index 1b4ba1d..2652bfe 100644 (file)
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
  */
 
 #include "Softmax.h"
+#include "QuantizationHelpers.h"
 
 #include <mir/ShapeRange.h>
 #include <mir/Tensor.h>
@@ -57,4 +59,78 @@ void Softmax(const mir::TensorVariant &arg, int axis, mir::TensorVariant &result
   }
 }
 
+inline void PopulateSoftmaxLookupTable(float *table, float input_scale, float beta)
+{
+  const float scale = -input_scale * beta;
+  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
+  for (int32_t val = 0; val <= max_uint8; ++val)
+    table[max_uint8 - val] = expf(scale * val);
+}
+
+void QuantizedSoftmax(const mir::TensorVariant &input, int axis, mir::TensorVariant &output)
+{
+  const auto &input_type = input.getType();
+  const auto &output_type = output.getType();
+
+  assert(input_type.isQuantized());
+  assert(output_type.isQuantized());
+
+  const auto input_shape = input_type.getShape();
+
+  assert(input_type.getElementType() == mir::DataType::UINT8);
+  assert(axis == input_shape.rank() - 1); // supported only last dim axis
+
+  double input_scale = input_type.getQuantization().getScale();
+  double output_scale = output_type.getQuantization().getScale();
+
+  const int trailing_dim = input_shape.rank() - 1;
+  int excluding_last_dim = 1;
+  for (int32_t i = 0; i < input_shape.rank() - 1; i++)
+  {
+    excluding_last_dim *= input_shape.dim(i);
+  }
+  const int last_dim = input_shape.dim(trailing_dim);
+
+  const int32_t clamp_max = std::numeric_limits<uint8_t>::max();
+  const int32_t clamp_min = std::numeric_limits<uint8_t>::min();
+
+  uint8_t *input_data = reinterpret_cast<uint8_t *>(input.atOffset(0));
+
+  float table[256];
+  PopulateSoftmaxLookupTable(table, input_scale, 1.f);
+
+  uint8_t *output_data = reinterpret_cast<uint8_t *>(output.atOffset(0));
+
+  for (int i = 0; i < excluding_last_dim; ++i)
+  {
+    int32_t max_val = std::numeric_limits<uint8_t>::min();
+    // Find max quantized value.
+    for (int j = 0; j < last_dim; ++j)
+    {
+      max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
+    }
+
+    float sum_exp = 0.0f;
+    const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
+    const float *table_offset = &table[max_uint8 - max_val];
+    // Calculate normalizer sum(exp(x)).
+    for (int j = 0; j < last_dim; ++j)
+    {
+      sum_exp += table_offset[input_data[j]];
+    }
+
+    const float inv_sum_exp = 1.0f / (sum_exp * output_scale);
+    // Normalize and quantize probabilities.
+    for (int j = 0; j < last_dim; ++j)
+    {
+      const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
+      const int32_t prob_quantized = static_cast<int32_t>(prob_rescaled + 0.5);
+      output_data[j] =
+          static_cast<uint8_t>(std::max(std::min(clamp_max, prob_quantized), clamp_min));
+    }
+    input_data += last_dim;
+    output_data += last_dim;
+  }
+}
+
 } // namespace nnc
index e51973c..23c38d2 100644 (file)
@@ -24,6 +24,8 @@ namespace nnc
 
 void Softmax(const mir::TensorVariant &arg, int axis, mir::TensorVariant &result);
 
+void QuantizedSoftmax(const mir::TensorVariant &input, int axis, mir::TensorVariant &output);
+
 } // namespace nnc
 
 #endif //_NNC_CORE_BACKEND_INTERPRETER_SOFTMAX_