typedef Eigen::ThreadPoolDevice CPUDevice;
namespace {
-bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 8; }
+bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 16; }
} // namespace
// -----------------------------------------------------------------------------
" >= ", max_));
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
- OP_REQUIRES(context, IsNumBitsValid(num_bits),
- InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+ OP_REQUIRES(
+ context, IsNumBitsValid(num_bits),
+ InvalidArgument("num_bits must be between 2 and 16, inclusive"));
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
" >= ", max_));
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
- OP_REQUIRES(context, IsNumBitsValid(num_bits),
- InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+ OP_REQUIRES(
+ context, IsNumBitsValid(num_bits),
+ InvalidArgument("num_bits must be between 2 and 16, inclusive"));
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
: OpKernel::OpKernel(context) {
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
- OP_REQUIRES(context, IsNumBitsValid(num_bits),
- InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+ OP_REQUIRES(
+ context, IsNumBitsValid(num_bits),
+ InvalidArgument("num_bits must be between 2 and 16, inclusive"));
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
: OpKernel::OpKernel(context) {
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
- OP_REQUIRES(context, IsNumBitsValid(num_bits),
- InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+ OP_REQUIRES(
+ context, IsNumBitsValid(num_bits),
+ InvalidArgument("num_bits must be between 2 and 16, inclusive"));
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
: OpKernel::OpKernel(context) {
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
- OP_REQUIRES(context, IsNumBitsValid(num_bits),
- InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+ OP_REQUIRES(
+ context, IsNumBitsValid(num_bits),
+ InvalidArgument("num_bits must be between 2 and 16, inclusive"));
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
: OpKernel::OpKernel(context) {
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
- OP_REQUIRES(context, IsNumBitsValid(num_bits),
- InvalidArgument("num_bits must be between 2 and 8, inclusive"));
+ OP_REQUIRES(
+ context, IsNumBitsValid(num_bits),
+ InvalidArgument("num_bits must be between 2 and 16, inclusive"));
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
const float quant_max_float = static_cast<float>(quant_max);
*scale = (max - min) / (quant_max_float - quant_min_float);
const float zero_point_from_min = quant_min_float - min / *scale;
- const uint8 nudged_zero_point = [zero_point_from_min, quant_min,
- quant_min_float, quant_max,
- quant_max_float] {
+ const uint16 nudged_zero_point = [zero_point_from_min, quant_min,
+ quant_min_float, quant_max,
+ quant_max_float] {
if (zero_point_from_min < quant_min_float) {
- return static_cast<uint8>(quant_min);
+ return static_cast<uint16>(quant_min);
}
if (zero_point_from_min > quant_max_float) {
- return static_cast<uint8>(quant_max);
+ return static_cast<uint16>(quant_max);
}
- return static_cast<uint8>(StdRound(zero_point_from_min));
+ return static_cast<uint16>(StdRound(zero_point_from_min));
}();
*nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
*nudged_max = (quant_max_float - nudged_zero_point) * (*scale);