From f9a8009db5af7d1019d67f93890e8ad00773d91c Mon Sep 17 00:00:00 2001 From: Raghuraman Krishnamoorthi Date: Thu, 8 Feb 2018 15:21:22 -0800 Subject: [PATCH] Update fake quant op to support bitwidths in the range 2 to 16, from 2 to 8. PiperOrigin-RevId: 185061307 --- tensorflow/core/kernels/fake_quant_ops.cc | 32 ++++++++++++++---------- tensorflow/core/kernels/fake_quant_ops_functor.h | 12 ++++----- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/tensorflow/core/kernels/fake_quant_ops.cc b/tensorflow/core/kernels/fake_quant_ops.cc index 68762af..f5e279e 100644 --- a/tensorflow/core/kernels/fake_quant_ops.cc +++ b/tensorflow/core/kernels/fake_quant_ops.cc @@ -45,7 +45,7 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; namespace { -bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 8; } +bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 16; } } // namespace // ----------------------------------------------------------------------------- @@ -65,8 +65,9 @@ class FakeQuantWithMinMaxArgsOp " >= ", max_)); int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -104,8 +105,9 @@ class FakeQuantWithMinMaxArgsGradientOp " >= ", max_)); int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -175,8 +177,9 @@ class FakeQuantWithMinMaxVarsOp : public OpKernel { : OpKernel::OpKernel(context) { int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -213,8 +216,9 @@ class FakeQuantWithMinMaxVarsGradientOp : public OpKernel { : OpKernel::OpKernel(context) { int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -302,8 +306,9 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel { : OpKernel::OpKernel(context) { int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -348,8 +353,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel { : OpKernel::OpKernel(context) { int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h index 8118986..d51acc3 100644 --- a/tensorflow/core/kernels/fake_quant_ops_functor.h +++ b/tensorflow/core/kernels/fake_quant_ops_functor.h @@ -45,16 +45,16 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void Nudge( const float quant_max_float = static_cast(quant_max); *scale = (max - min) / (quant_max_float - quant_min_float); const float zero_point_from_min = quant_min_float - min / *scale; - const uint8 nudged_zero_point = [zero_point_from_min, quant_min, - quant_min_float, quant_max, - quant_max_float] { + const uint16 nudged_zero_point = [zero_point_from_min, quant_min, + quant_min_float, quant_max, + quant_max_float] { if (zero_point_from_min < quant_min_float) { - return static_cast(quant_min); + return static_cast(quant_min); } if (zero_point_from_min > quant_max_float) { - return static_cast(quant_max); + return static_cast(quant_max); } - return static_cast(StdRound(zero_point_from_min)); + return static_cast(StdRound(zero_point_from_min)); }(); *nudged_min = (quant_min_float - nudged_zero_point) * (*scale); *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); -- 2.7.4