Update fake quant op to support bitwidths in the range 2 to 16, from 2 to 8.
authorRaghuraman Krishnamoorthi <raghuramank@google.com>
Thu, 8 Feb 2018 23:21:22 +0000 (15:21 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Feb 2018 23:26:32 +0000 (15:26 -0800)
PiperOrigin-RevId: 185061307

tensorflow/core/kernels/fake_quant_ops.cc
tensorflow/core/kernels/fake_quant_ops_functor.h

index 68762af8cf1e76211c0229163d9dce44fc0ad153..f5e279eca4c6d3492419a507c7d070613e169b64 100644 (file)
@@ -45,7 +45,7 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
 namespace {
-bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 8; }
+bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 16; }
 }  // namespace
 
 // -----------------------------------------------------------------------------
@@ -65,8 +65,9 @@ class FakeQuantWithMinMaxArgsOp
                                 " >= ", max_));
     int num_bits;
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
-    OP_REQUIRES(context, IsNumBitsValid(num_bits),
-                InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+    OP_REQUIRES(
+        context, IsNumBitsValid(num_bits),
+        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
     bool narrow_range;
     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
     quant_min_ = narrow_range ? 1 : 0;
@@ -104,8 +105,9 @@ class FakeQuantWithMinMaxArgsGradientOp
                                 " >= ", max_));
     int num_bits;
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
-    OP_REQUIRES(context, IsNumBitsValid(num_bits),
-                InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+    OP_REQUIRES(
+        context, IsNumBitsValid(num_bits),
+        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
     bool narrow_range;
     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
     quant_min_ = narrow_range ? 1 : 0;
@@ -175,8 +177,9 @@ class FakeQuantWithMinMaxVarsOp : public OpKernel {
       : OpKernel::OpKernel(context) {
     int num_bits;
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
-    OP_REQUIRES(context, IsNumBitsValid(num_bits),
-                InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+    OP_REQUIRES(
+        context, IsNumBitsValid(num_bits),
+        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
     bool narrow_range;
     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
     quant_min_ = narrow_range ? 1 : 0;
@@ -213,8 +216,9 @@ class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
       : OpKernel::OpKernel(context) {
     int num_bits;
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
-    OP_REQUIRES(context, IsNumBitsValid(num_bits),
-                InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+    OP_REQUIRES(
+        context, IsNumBitsValid(num_bits),
+        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
     bool narrow_range;
     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
     quant_min_ = narrow_range ? 1 : 0;
@@ -302,8 +306,9 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
       : OpKernel::OpKernel(context) {
     int num_bits;
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
-    OP_REQUIRES(context, IsNumBitsValid(num_bits),
-                InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+    OP_REQUIRES(
+        context, IsNumBitsValid(num_bits),
+        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
     bool narrow_range;
     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
     quant_min_ = narrow_range ? 1 : 0;
@@ -348,8 +353,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
       : OpKernel::OpKernel(context) {
     int num_bits;
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
-    OP_REQUIRES(context, IsNumBitsValid(num_bits),
-                InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+    OP_REQUIRES(
+        context, IsNumBitsValid(num_bits),
+        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
     bool narrow_range;
     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
     quant_min_ = narrow_range ? 1 : 0;
index 81189866c34819306231edc2073fbdc23fbb9baf..d51acc38ef7e5a865f51ac319a3ad16198714dd9 100644 (file)
@@ -45,16 +45,16 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void Nudge(
   const float quant_max_float = static_cast<float>(quant_max);
   *scale = (max - min) / (quant_max_float - quant_min_float);
   const float zero_point_from_min = quant_min_float - min / *scale;
-  const uint8 nudged_zero_point = [zero_point_from_min, quant_min,
-                                   quant_min_float, quant_max,
-                                   quant_max_float] {
+  const uint16 nudged_zero_point = [zero_point_from_min, quant_min,
+                                    quant_min_float, quant_max,
+                                    quant_max_float] {
     if (zero_point_from_min < quant_min_float) {
-      return static_cast<uint8>(quant_min);
+      return static_cast<uint16>(quant_min);
     }
     if (zero_point_from_min > quant_max_float) {
-      return static_cast<uint8>(quant_max);
+      return static_cast<uint16>(quant_max);
     }
-    return static_cast<uint8>(StdRound(zero_point_from_min));
+    return static_cast<uint16>(StdRound(zero_point_from_min));
   }();
   *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
   *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);