using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT16;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_UINT8;
case ArrayDataType::kInt64:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
break;
+ case ArrayDataType::kInt16:
+ (*placeholder->mutable_attr())["dtype"].set_type(DT_INT16);
+ break;
default:
LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
}
image_input_op->outputs = {dequantized_input_name};
model->operators.emplace(model->operators.begin(), image_input_op);
- CHECK(input_array.final_data_type == ArrayDataType::kUint8);
- input_array.data_type = ArrayDataType::kUint8;
dequantized_input_array.data_type = ArrayDataType::kFloat;
const auto& input_minmax = input_array.GetMinMax();
auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax();
dequantized_input_minmax = input_minmax;
auto& input_qparams = input_array.GetOrCreateQuantizationParams();
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax,
- &input_qparams);
+ input_array.data_type = input_array.final_data_type;
+ if (input_array.data_type == ArrayDataType::kUint8) {
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax,
+ &input_qparams);
+ } else if (input_array.data_type == ArrayDataType::kInt16) {
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(input_minmax,
+ &input_qparams);
+ } else {
+ LOG(FATAL) << "unhandled data type";
+ }
transformation->AddMessageF(
"Created %s"
if (parsed_model_flags.arrays_extra_info_file.specified()) {
string arrays_extra_info_file_contents;
- port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(),
- &arrays_extra_info_file_contents,
- port::file::Defaults());
+ CHECK(port::file::GetContents(
+ parsed_model_flags.arrays_extra_info_file.value(),
+ &arrays_extra_info_file_contents, port::file::Defaults())
+ .ok());
ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
model_flags->mutable_arrays_extra_info());
}
}
bool IsRealValued(toco::ArrayDataType type) {
+ // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used
+ // for quantized real-number values, and no other integer type is ever used
+ // for that. This is dirty, should be resolved as part of a more general push
+ // to more explicitly distinguish between true-integers and
+ // integers used as quantized values representing real numbers.
return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
- type == toco::ArrayDataType::kUint8);
+ type == toco::ArrayDataType::kUint8 ||
+ type == toco::ArrayDataType::kInt16);
}
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
const FileFormat output_format = toco_flags.output_format();
ArrayDataType type;
- if (toco_flags.has_inference_input_type()) {
+ if (!SupportsQuantization(output_format)) {
+ // Data type is implicitly float for non-quantized formats
+ type = ArrayDataType::kFloat;
+ } else if (toco_flags.has_inference_input_type()) {
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
} else if (toco_flags.has_inference_type()) {
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
- } else if (!SupportsQuantization(output_format)) {
- // Data type is implicitly float for non-quantized formats
- type = ArrayDataType::kFloat;
} else {
// Nothing to do. Data types stay as-is.
return;
}
void Transform(const TocoFlags& toco_flags, Model* model) {
- // Clean up after import.
- SetFinalDataTypeOnInputs(toco_flags, model);
- UseArraysExtraInfo(model);
- FinishBuildingRNNStates(model);
-
const FileFormat output_format = toco_flags.output_format();
const IODataType inference_type = toco_flags.inference_type();
<< "Quantized inference is not allowed with float inputs.";
}
+ // Clean up after import.
+ SetFinalDataTypeOnInputs(toco_flags, model);
+ UseArraysExtraInfo(model, quantize_output);
+ FinishBuildingRNNStates(model);
+
// Remove unused ops before performing any other optimizations. This is to
// stop optimizations from crossing the input/output boundaries. For example
// this will stop BatchNorm fusing if the output node is in between a conv
const float mean_value = input_array_proto.mean_value();
const float std_value = input_array_proto.std_value();
MinMax input_minmax;
- input_minmax.min = (0.f - mean_value) / std_value;
- input_minmax.max = (255.f - mean_value) / std_value;
+ float qmin = 0, qmax = 255;
+ if (input_array.data_type == ArrayDataType::kInt16) {
+ qmin = -32768;
+ qmax = 32767;
+ }
+ input_minmax.min = (qmin - mean_value) / std_value;
+ input_minmax.max = (qmax - mean_value) / std_value;
if (input_array.minmax) {
if (input_array_proto.has_mean_value() ||
input_array_proto.has_std_value()) {
- CHECK(input_minmax == *input_array.minmax)
+ const double width = input_minmax.max - input_minmax.min;
+ const double kMinMaxAllowedDiff = 1e-6 * width;
+ CHECK(std::abs(input_minmax.min - input_array.minmax->min) <
+ kMinMaxAllowedDiff &&
+ std::abs(input_minmax.max - input_array.minmax->max) <
+ kMinMaxAllowedDiff)
<< input_minmax.min << ", " << input_minmax.max
<< " != " << input_array.minmax->min << ", "
<< input_array.minmax->max;
}
}
-void UseArraysExtraInfo(Model* model) {
+void UseArraysExtraInfo(Model* model, bool quantize_output) {
for (const auto& entry : model->flags.arrays_extra_info().entries()) {
if (!model->HasArray(entry.name())) {
continue;
minmax.min = entry.min();
minmax.max = entry.max();
}
- if (entry.has_data_type()) {
+ if (entry.has_data_type() && quantize_output) {
array.final_data_type =
ConvertIODataTypeToArrayDataType(entry.data_type());
}
// already quantized, then case (a) should hold.
void FinishBuildingRNNStates(Model* model);
-void UseArraysExtraInfo(Model* model);
+void UseArraysExtraInfo(Model* model, bool quantize_output);
} // namespace toco