Introduce cpu quant8 softmax kernel (#4953)
author오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 10 Apr 2019 05:40:53 +0000 (14:40 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 10 Apr 2019 05:40:53 +0000 (14:40 +0900)
Introduce cpu quantized int8 softmax kernel from tflite and gemmlowp
Use kernel in neurun cpu backend

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
libs/cker/include/cker/FixedPoint.h
libs/cker/include/cker/Utils.h
libs/cker/include/cker/operation/SoftMax.h
runtimes/neurun/backend/cpu/kernel/SoftMaxLayer.cc

index 33178a9..653a56d 100644 (file)
@@ -26,6 +26,15 @@ namespace nnfw
 namespace cker
 {
 
+inline int32_t RoundingHalfSum(int32_t a, int32_t b)
+{
+  int64_t a64 = a;
+  int64_t b64 = b;
+  int64_t sum = a64 + b64;
+  int64_t sign = sum >= 0 ? 1 : -1;
+  return static_cast<int32_t>((sum + sign) / 2);
+}
+
 inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b)
 {
   bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
@@ -47,8 +56,230 @@ inline int32_t RoundingDivideByPOT(int32_t x, int exponent)
   const int32_t zero = 0;
   const int32_t one = 1;
   const int32_t remainder = x & mask;
-  const int32_t threshold = (mask >> 1) + ((x < zero ? ~zero : zero) & one);
-  return ((x >> exponent) + (((remainder > threshold) ? ~zero : zero) & one));
+  const int32_t threshold = (mask >> 1) + ((x < zero) ? one : zero);
+  return ((x >> exponent) + ((remainder > threshold) ? one : zero));
+}
+
+// Returns the product of a run-time integer value by a compile-time power
+// of two, with either a positive exponent (equivalent to an arithmetic
+// left shift, saturating) or a negative exponent (equivalent to an arithmetic
+// right shift, rounding to nearest).
+template <int Exponent, int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
+struct ImplSaturatingRoundingMultiplyByPOT
+{
+};
+
+template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, 0>
+{
+  static int32_t eval(int32_t x) { return x; }
+};
+
+template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, 1>
+{
+  static int32_t eval(int32_t x)
+  {
+    const int32_t min = (std::numeric_limits<int32_t>::min());
+    const int32_t max = (std::numeric_limits<int32_t>::max());
+    const int32_t threshold = ((1 << (31 - Exponent)) - 1);
+    const int32_t zero = 0;
+    const int32_t one = 1;
+
+    const int32_t positive_mask = ((x > threshold) ? ~zero : zero);
+    const int32_t negative_mask = ((x < -threshold) ? ~zero : zero);
+
+    int32_t result = (x * (one << Exponent));
+    result = (positive_mask ? max : result);
+    result = (negative_mask ? min : result);
+    return result;
+  }
+};
+
+template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, -1>
+{
+  static int32_t eval(int32_t x) { return RoundingDivideByPOT(x, -Exponent); }
+};
+
+template <int Exponent> int32_t SaturatingRoundingMultiplyByPOT(int32_t x)
+{
+  return ImplSaturatingRoundingMultiplyByPOT<Exponent>::eval(x);
+}
+
+template <int tIntegerBits> class FixedPoint
+{
+public:
+  static constexpr int kTotalBits = 8 * sizeof(int32_t);
+  static constexpr int kIntegerBits = tIntegerBits;
+  static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
+  static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits");
+
+  static const int32_t ScalarRawMax() { return std::numeric_limits<int32_t>::max(); }
+
+  static FixedPoint FromRaw(int32_t x)
+  {
+    FixedPoint retval;
+    retval.raw() = x;
+    return retval;
+  }
+
+  static FixedPoint FromScalarRaw(int32_t x) { return FromRaw(x); }
+
+  template <int Exponent> static FixedPoint ConstantPOT()
+  {
+    static constexpr int kOffset = kFractionalBits + Exponent;
+    static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format");
+    return FromScalarRaw((int32_t)1 << kOffset);
+  }
+
+  static FixedPoint Zero() { return FromScalarRaw(0); }
+
+  static FixedPoint One()
+  {
+    return FromScalarRaw(kIntegerBits == 0 ? ScalarRawMax() : ((int32_t)1 << kFractionalBits));
+  }
+
+  int32_t raw() const { return i_; }
+  int32_t &raw() { return i_; }
+
+private:
+  int32_t i_;
+};
+
+// A FixedPoint multiplication is just a
+// SaturatingRoundingDoublingHighMul operation on the underlying
+// raw integer values. The IntegerBits simply add up, as is obvious
+// from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
+template <int tIntegerBits_a, int tIntegerBits_b>
+FixedPoint<tIntegerBits_a + tIntegerBits_b> operator*(FixedPoint<tIntegerBits_a> a,
+                                                      FixedPoint<tIntegerBits_b> b)
+{
+  FixedPoint<tIntegerBits_a + tIntegerBits_b> c;
+  c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
+  return c;
+}
+
+// Tweaking IntegerBits gives exact multiplication by a power of two.
+template <int tExponent, int tIntegerBits>
+FixedPoint<tExponent + tIntegerBits> ExactMulByPot(FixedPoint<tIntegerBits> a)
+{
+  FixedPoint<tExponent + tIntegerBits> c;
+  c.raw() = a.raw();
+  return c;
+}
+
+template <int tIntegerBits>
+FixedPoint<tIntegerBits> operator+(FixedPoint<tIntegerBits> a, FixedPoint<tIntegerBits> b)
+{
+  return FixedPoint<tIntegerBits>::FromRaw((a.raw() + b.raw()));
+}
+template <int tIntegerBits>
+FixedPoint<tIntegerBits> operator-(FixedPoint<tIntegerBits> a, FixedPoint<tIntegerBits> b)
+{
+  return FixedPoint<tIntegerBits>::FromRaw((a.raw() - b.raw()));
+}
+template <int tIntegerBits>
+FixedPoint<tIntegerBits> operator&(FixedPoint<tIntegerBits> a, FixedPoint<tIntegerBits> b)
+{
+  return FixedPoint<tIntegerBits>::FromRaw((a.raw() & b.raw()));
+}
+
+// Rescale changes the number of IntegerBits and updates the underlying
+// raw integer value accordingly.
+template <int tIntegerBitsDst, int tIntegerBitsSrc>
+FixedPoint<tIntegerBitsDst> Rescale(FixedPoint<tIntegerBitsSrc> x)
+{
+  static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
+  FixedPoint<tIntegerBitsDst> result;
+  result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
+  return result;
+}
+
+// Implementation of exponential function.
+
+// Returns exp(x) for x in [-1/4, 0).
+inline FixedPoint<0> exp_on_interval_between_negative_one_quarter_and_0_excl(FixedPoint<0> a)
+{
+  typedef FixedPoint<0> F;
+  const F constant_term = F::FromScalarRaw(RoundingDivideByPOT(1895147668, 0));
+  const F constant_1_over_3 = F::FromScalarRaw(RoundingDivideByPOT(715827883, 0));
+  // We're evaluating a Taylor expansion around -1/8, so we do the change of
+  // variable: x = a + 1/8.
+  // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
+  F x = a + F::template ConstantPOT<-3>();
+  F x2 = x * x;
+  F x3 = x2 * x;
+  F x4 = x2 * x2;
+  F x4_over_4 = F::FromScalarRaw(SaturatingRoundingMultiplyByPOT<-2>(x4.raw()));
+  F x4_over_24_plus_x3_over_6_plus_x2_over_2 = F::FromScalarRaw(
+      SaturatingRoundingMultiplyByPOT<-1>((((x4_over_4 + x3) * constant_1_over_3) + x2).raw()));
+  return (constant_term + constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
+}
+
+// Returns exp(x) for x < 0.
+template <int tIntegerBits> FixedPoint<0> exp_on_negative_values(FixedPoint<tIntegerBits> a)
+{
+  typedef FixedPoint<tIntegerBits> InputF;
+  typedef FixedPoint<0> ResultF;
+  static constexpr int kFractionalBits = InputF::kFractionalBits;
+  static constexpr int kIntegerBits = InputF::kIntegerBits;
+  const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
+  InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
+  InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
+  ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
+      Rescale<0>(a_mod_quarter_minus_one_quarter));
+  int32_t remainder = (a_mod_quarter_minus_one_quarter - a).raw();
+
+  const int32_t zero = 0;
+
+#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)                 \
+  if (kIntegerBits > Exponent)                                                      \
+  {                                                                                 \
+    const ResultF kMultiplier =                                                     \
+        ResultF::FromScalarRaw(RoundingDivideByPOT(FixedPointMultiplier, 0));       \
+    static constexpr int kShiftAmount =                                             \
+        ((kIntegerBits > Exponent) ? (kFractionalBits + Exponent) : 0);             \
+    result = ((remainder & (1 << kShiftAmount)) ? (result * kMultiplier) : result); \
+  }
+
+  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
+  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
+  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
+
+#undef GEMMLOWP_EXP_BARREL_SHIFTER
+
+  static constexpr int clampB = ((kIntegerBits > 5) ? (36 - kIntegerBits) : 0);
+  if (kIntegerBits > 5)
+  {
+    const InputF clamp = InputF::FromScalarRaw(RoundingDivideByPOT(-(1 << clampB), 0));
+    result.raw() = ((a.raw() < clamp.raw()) ? ResultF::Zero().raw() : result.raw());
+  }
+
+  result.raw() = (a.raw() ? result.raw() : ResultF::One().raw());
+  return result;
+}
+
+// Returns 1 / (1 + x) for x in (0, 1).
+inline FixedPoint<0> one_over_one_plus_x_for_x_in_0_1(FixedPoint<0> a)
+{
+  typedef FixedPoint<0> F0;
+  typedef FixedPoint<2> F2;
+  F0 half_denominator = F0::FromScalarRaw(RoundingHalfSum(a.raw(), F0::One().raw()));
+  // Newton-Raphson division
+  // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
+  // Refer to that page for the logic behind the 48/17 and 32/17 constants.
+  const F2 constant_48_over_17 = F2::FromScalarRaw(RoundingDivideByPOT(1515870810, 0));
+  const F2 constant_neg_32_over_17 = F2::FromScalarRaw(RoundingDivideByPOT(-1010580540, 0));
+  F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
+  for (int i = 0; i < 3; i++)
+  {
+    F2 half_denominator_times_x = half_denominator * x;
+    F2 one_minus_half_denominator_times_x = F2::One() - half_denominator_times_x;
+    x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
+  }
+  return Rescale<0>(ExactMulByPot<-1>(x));
 }
 
 } // namespace cker
index 043cb5c..af98fd8 100644 (file)
@@ -19,6 +19,7 @@
 #define __NNFW_CKER_UTILS_H__
 
 #include <algorithm>
+#include <cstdint>
 
 #include "cker/FixedPoint.h"
 
@@ -41,6 +42,24 @@ inline int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multip
       SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), right_shift);
 }
 
+inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(int32_t x, int32_t quantized_multiplier,
+                                                           int left_shift)
+{
+  return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier);
+}
+
+inline int CountLeadingZeros(uint32_t integer_input)
+{
+  const uint32_t one_in_leading_positive = 1U << 31;
+  int leading_zeros = 0;
+  while (integer_input < one_in_leading_positive)
+  {
+    integer_input <<= 1;
+    ++leading_zeros;
+  }
+  return leading_zeros;
+}
+
 } // namespace cker
 } // namespace nnfw
 
index 3acd6b1..d3082f7 100644 (file)
@@ -19,6 +19,8 @@
 #define __NNFW_CKER_SOFTMAX_H__
 
 #include "cker/Shape.h"
+#include "cker/Utils.h"
+#include "cker/FixedPoint.h"
 
 #include <cmath>
 
@@ -74,6 +76,88 @@ inline void Softmax(const SoftmaxParams &params, const Shape &input_shape, const
   }
 }
 
+inline void Softmax(const SoftmaxParams &params, const Shape &input_shape,
+                    const uint8_t *input_data, const Shape &output_shape, uint8_t *output_data)
+{
+  const int32_t input_beta_multiplier = params.input_multiplier;
+  const int32_t input_beta_left_shift = params.input_left_shift;
+  const int diff_min = params.diff_min;
+  // The representation chosen for the input to the exp() function is Q5.26.
+  // We need to leave extra space since values that we skip might be as large as
+  // -32 before multiplying by input_beta_multiplier, and therefore as large as
+  // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
+  // accumulation, but exp(-16) definitely is.
+  static const int kScaledDiffIntegerBits = 5;
+  static const int kAccumulationIntegerBits = 12;
+  using FixedPointScaledDiff = FixedPoint<kScaledDiffIntegerBits>;
+  using FixedPointAccum = FixedPoint<kAccumulationIntegerBits>;
+  using FixedPoint0 = FixedPoint<0>;
+
+  const int trailing_dim = input_shape.DimensionsCount() - 1;
+  const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+  const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+  for (int i = 0; i < outer_size; ++i)
+  {
+    uint8_t max_in_row = 0;
+    for (int c = 0; c < depth; ++c)
+    {
+      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
+    }
+
+    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+    for (int c = 0; c < depth; ++c)
+    {
+      int32_t input_diff = static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
+      if (input_diff >= diff_min)
+      {
+        const int32_t input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne(
+            input_diff, input_beta_multiplier, input_beta_left_shift);
+        const FixedPointScaledDiff scaled_diff_f8 =
+            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+        sum_of_exps =
+            sum_of_exps + Rescale<kAccumulationIntegerBits>(exp_on_negative_values(scaled_diff_f8));
+      }
+    }
+
+    int32_t fixed_sum_of_exps = sum_of_exps.raw();
+    int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(fixed_sum_of_exps));
+    // This is the number of bits to the left of the binary point above 1.0.
+    // Consider fixed_sum_of_exps=1.25.  In that case shifted_scale=0.8 and
+    // no later adjustment will be needed.
+    int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
+    int32_t shifted_sum_minus_one =
+        static_cast<int32_t>((static_cast<uint32_t>(fixed_sum_of_exps) << headroom_plus_one) -
+                             (static_cast<uint32_t>(1) << 31));
+
+    FixedPoint0 shifted_scale =
+        one_over_one_plus_x_for_x_in_0_1(FixedPoint0::FromRaw(shifted_sum_minus_one));
+
+    for (int c = 0; c < depth; ++c)
+    {
+      int32_t input_diff = static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
+      if (input_diff >= diff_min)
+      {
+        const int32_t input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne(
+            input_diff, input_beta_multiplier, input_beta_left_shift);
+        const FixedPointScaledDiff scaled_diff_f8 =
+            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+        FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+        int32_t unsat_output =
+            RoundingDivideByPOT((shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+        output_data[i * depth + c] = static_cast<uint8_t>(
+            std::max(std::min(unsat_output, static_cast<int32_t>(255)), static_cast<int32_t>(0)));
+      }
+      else
+      {
+        output_data[i * depth + c] = 0;
+      }
+    }
+  }
+}
+
 } // namespace cker
 } // namespace nnfw
 
index 1fd14d8..6e005fb 100644 (file)
@@ -18,7 +18,6 @@
 
 #include <cker/operation/SoftMax.h>
 
-#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "OperationUtils.h"
 
 namespace neurun
@@ -41,7 +40,7 @@ SoftMaxLayer::SoftMaxLayer()
 void Softmax(const float *in, const int input_size, const int batch_size, const float beta,
              float *out)
 {
-  TF_LITE_ASSERT(input_size > 0);
+  assert(input_size > 0);
 
   // For each batch
   for (int b = 0; b < batch_size; b++)
@@ -137,12 +136,12 @@ bool SoftMaxLayer::softmaxQuant8()
   }
   float diff_min = -1.0f * CalculateInputRadius(kScaledDiffIntegerBits, input_left_shift);
 
-  ::tflite::SoftmaxParams op_params;
+  nnfw::cker::SoftmaxParams op_params;
   op_params.input_multiplier = input_multiplier;
   op_params.input_left_shift = input_left_shift;
   op_params.diff_min = diff_min;
-  ::tflite::optimized_ops::Softmax(op_params, convertShapeToTFLiteShape(shapeIn4D), _inputData.u8,
-                                   convertShapeToTFLiteShape(shapeIn4D), _outputData.u8);
+  nnfw::cker::Softmax(op_params, convertShapeToCkerShape(shapeIn4D), _inputData.u8,
+                      convertShapeToCkerShape(shapeIn4D), _outputData.u8);
   return true;
 }