New quantized log(x) for x > 1. Used for LogSoftmax.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 23 May 2018 21:12:40 +0000 (14:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 21:15:12 +0000 (14:15 -0700)
PiperOrigin-RevId: 197786738

tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h

index 1b4660e..025e282 100644 (file)
@@ -140,6 +140,45 @@ MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
   return MatrixMap<Scalar>(data, rows, cols);
 }
 
+// This is like the template-parameter version, except that the power-of-two is
+// passed as a function parameter. The template version is to be preferred,
+// since some target hardware optimizations depend on the range of the exponent.
+template <typename IntegerType>
+IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
+  if (exponent == 0) {
+    return x;
+  }
+  using ScalarIntegerType =
+      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
+  const IntegerType min =
+      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
+  const IntegerType max =
+      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
+
+  const std::int32_t threshold =
+      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
+  const IntegerType positive_mask =
+      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
+  const IntegerType negative_mask =
+      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
+
+  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
+  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
+  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
+  return result;
+}
+
+// This is like the template-parameter version, except that the power-of-two is
+// passed as a function parameter. See raw-integer version for further comments.
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits>
+SaturatingRoundingMultiplyByPOTParam(
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
+  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
+}
+
 // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
 // BROADCASTING.
 //
@@ -4559,6 +4598,119 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
   }
 }
 
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1_impl(
+    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32>::digits - 1);
+  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32>::digits);
+  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+  // The reason for accumulating the result with an extra bit of headroom is
+  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
+  // recip_denom will otherwise introduce an error.
+  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
+  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
+
+  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1488522236, std::log(2.0));
+  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
+  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1518500250, std::sqrt(0.5));
+  const FixedPoint0 one_quarter =
+      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
+
+  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
+  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
+  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1057819769,
+      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
+  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
+
+  const FixedPointAccum shifted_quarter =
+      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
+
+  // Reinterpret the input value as Q0.31, because we will figure out the
+  // required shift "ourselves" instead of using, say, Rescale.
+  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
+  // z_a_pow_2 = input_integer_bits - z_a_headroom;
+  int z_a_headroom_plus_1 = __builtin_clz(static_cast<uint32>(z_a.raw()));
+  FixedPoint0 r_a_tmp =
+      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
+  const int32 r_a_raw =
+      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
+  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
+  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
+  //                   InputIntegerBits - z_b_headroom - 0.25);
+  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
+      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+          InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
+      shifted_quarter);
+
+  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
+  FixedPoint0 z_b = z_a * sqrt_half;
+  int z_b_headroom = __builtin_clz(static_cast<uint32>(z_b.raw())) - 1;
+  const int32 r_b_raw =
+      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
+  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
+      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+          InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
+      shifted_quarter);
+
+  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
+  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
+      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
+
+  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
+  FixedPoint0 q = r - sqrt_sqrt_half;
+  q = q + q;
+
+  const FixedPoint0 common_sq = q * q;
+  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
+  const FixedPoint0 denom_minus_one_0 =
+      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
+  const FixedPoint0 recip_denom =
+      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
+
+  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
+  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
+                                              num_scaled * recip_denom);
+}
+
+// Minimum output bits to accommodate log of maximum input range.  It actually
+// does not matter if one considers, say, [-64,64] or [-64,64).
+//
+// For example, run this through Octave:
+// [0:127; ...
+//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
+//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
+constexpr int min_log_x_output_bits(int input_bits) {
+  return input_bits > 90
+             ? 7
+             : input_bits > 44
+                   ? 6
+                   : input_bits > 21
+                         ? 5
+                         : input_bits > 10
+                               ? 4
+                               : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
+}
+
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1(
+    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+  static_assert(
+      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
+      "Output integer bits must be sufficent to accommodate logs of inputs.");
+  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
+                                                     InputIntegerBits>(
+      input_val);
+}
+
 // Currently just a copy of the reference code.
 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
                        int32 input_multiplier, int32 input_left_shift,
@@ -4604,13 +4756,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
       }
     }
 
-    // TODO(b/77858996): Implement fixed-point log().
-    // Not a fully-quantized implementation: floating-point log().
-    const float float_log_sum_of_exps =
-        std::log(static_cast<float>(sum_of_exps.raw()) /
-                 (1 << (31 - kAccumulationIntegerBits)));
-    const int32 fixed_log_sum_of_exps = static_cast<int32>(TfLiteRound(
-        float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits))));
+    const int32 fixed_log_sum_of_exps =
+        log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
+            sum_of_exps)
+            .raw();
 
     // rescaled_diff_min is smallest representable in
     // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
index cd4af48..f6d8d32 100644 (file)
@@ -33,8 +33,139 @@ limitations under the License.
 #include "tensorflow/contrib/lite/kernels/internal/types.h"
 
 namespace tflite {
+
+// TODO(b/77858996): Add these to gemmlowp.
+template <typename IntegerType>
+IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
+  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  return a;
+}
+
+template <>
+inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
+  std::int64_t a64 = a;
+  std::int64_t b64 = b;
+  std::int64_t sum = a64 + b64;
+  return static_cast<std::int32_t>(std::min(
+      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
+      std::max(
+          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
+          sum)));
+}
+
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
+  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
+}
+
+template <typename IntegerType>
+IntegerType SaturatingSub(IntegerType a, IntegerType b) {
+  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  return a;
+}
+
+template <>
+inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
+  std::int32_t a32 = a;
+  std::int32_t b32 = b;
+  std::int32_t diff = a32 - b32;
+  return static_cast<std::int16_t>(std::min(32767, std::max(-32768, diff)));
+}
+
+template <>
+inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
+  std::int64_t a64 = a;
+  std::int64_t b64 = b;
+  std::int64_t diff = a64 - b64;
+  return static_cast<std::int32_t>(std::min(
+      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
+      std::max(
+          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
+          diff)));
+}
+
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
+  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+      SaturatingSub(a.raw(), b.raw()));
+}
+// End section to be moved to gemmlowp.
+
 namespace reference_ops {
 
+inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
+    int32 x, int32 quantized_multiplier, int right_shift) {
+  using gemmlowp::RoundingDivideByPOT;
+  using gemmlowp::SaturatingRoundingDoublingHighMul;
+  return RoundingDivideByPOT(
+      SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
+}
+
+inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
+    int32 x, int32 quantized_multiplier, int left_shift) {
+  using gemmlowp::SaturatingRoundingDoublingHighMul;
+  return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
+                                           quantized_multiplier);
+}
+
+template <typename T>
+int CountLeadingZeros(T integer_input) {
+  static_assert(std::is_unsigned<T>::value,
+                "Only unsigned integer types handled.");
+  if (integer_input == 0) {
+    return std::numeric_limits<T>::digits;
+  }
+  const T one_in_leading_positive = static_cast<T>(1)
+                                    << (std::numeric_limits<T>::digits - 1);
+  int leading_zeros = 0;
+  while (integer_input < one_in_leading_positive) {
+    integer_input <<= 1;
+    ++leading_zeros;
+  }
+  return leading_zeros;
+}
+
+template <typename IntegerType>
+IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
+  if (exponent == 0) {
+    return x;
+  }
+  using ScalarIntegerType =
+      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
+  const IntegerType min =
+      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
+  const IntegerType max =
+      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
+
+  const std::int32_t threshold =
+      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
+  const IntegerType positive_mask =
+      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
+  const IntegerType negative_mask =
+      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
+
+  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
+  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
+  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
+  return result;
+}
+
+// If we want to leave IntegerBits fixed, then multiplication
+// by a power of two has to be saturating/rounding, not exact anymore.
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits>
+SaturatingRoundingMultiplyByPOTParam(
+    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
+  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
+}
+
 // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
 // BROADCASTING.
 //
@@ -2642,6 +2773,121 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
   }
 }
 
+// Although currently the name of this function says that it cannot handle
+// values less than 1, in practice it can handle as low as 1/x_max, where
+// x_max is the largest representable input.  In other words, the output range
+// is symmetric.
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1_impl(
+    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+  // The reason for accumulating the result with an extra bit of headroom is
+  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
+  // recip_denom will otherwise introduce an error.
+  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
+  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
+
+  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1488522236, std::log(2.0));
+  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
+  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1518500250, std::sqrt(0.5));
+  const FixedPoint0 one_quarter =
+      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
+
+  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
+  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
+  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 1057819769,
+      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
+  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
+
+  const FixedPointAccum shifted_quarter =
+      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
+
+  // Reinterpret the input value as Q0.31, because we will figure out the
+  // required shift "ourselves" instead of using, say, Rescale.
+  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
+  // z_a_pow_2 = input_integer_bits - z_a_headroom;
+  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
+  FixedPoint0 r_a_tmp =
+      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
+  const int32 r_a_raw =
+      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
+  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
+  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
+  //                   InputIntegerBits - z_b_headroom - 0.25);
+  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
+      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+          InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
+      shifted_quarter);
+
+  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
+  FixedPoint0 z_b = z_a * sqrt_half;
+  int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
+  const int32 r_b_raw =
+      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
+  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
+      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+          InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
+      shifted_quarter);
+
+  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
+  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
+      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
+
+  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
+  FixedPoint0 q = r - sqrt_sqrt_half;
+  q = q + q;
+
+  const FixedPoint0 common_sq = q * q;
+  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
+  const FixedPoint0 denom_minus_one_0 =
+      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
+  const FixedPoint0 recip_denom =
+      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
+
+  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
+  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
+                                              num_scaled * recip_denom);
+}
+
+// Minimum output bits to accommodate log of maximum input range.  It actually
+// does not matter if one considers, say, [-64,64] or [-64,64).
+//
+// For example, run this through Octave:
+// [0:127; ...
+//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
+//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
+constexpr int min_log_x_output_bits(int input_bits) {
+  return input_bits > 90
+             ? 7
+             : input_bits > 44
+                   ? 6
+                   : input_bits > 21
+                         ? 5
+                         : input_bits > 10
+                               ? 4
+                               : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
+}
+
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1(
+    gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+  static_assert(
+      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
+      "Output integer bits must be sufficent to accommodate logs of inputs.");
+  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
+                                                     InputIntegerBits>(
+      input_val);
+}
+
 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
                        int32 input_multiplier, int32 input_left_shift,
                        int32 reverse_scaling_divisor,
@@ -2684,13 +2930,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
       }
     }
 
-    // TODO(b/77858996): Implement fixed-point log().
-    // Not a fully-quantized implementation: floating-point log().
-    const float float_log_sum_of_exps =
-        std::log(static_cast<float>(sum_of_exps.raw()) /
-                 (1 << (31 - kAccumulationIntegerBits)));
-    const int32 fixed_log_sum_of_exps = static_cast<int32>(TfLiteRound(
-        float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits))));
+    const int32 fixed_log_sum_of_exps =
+        log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
+            sum_of_exps)
+            .raw();
 
     // rescaled_diff_min is smallest representable in
     // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the