#include <map>
#include <cpp14/memory.h>
#include <fstream>
+#include <limits>
namespace neurun
{
return model::DataType::INT32;
case TensorType::TensorType_BOOL:
return model::DataType::BOOL8;
- case TensorType::TensorType_INT8:
+ case TensorType::TensorType_UINT8:
return model::DataType::QUANT8_ASYMM;
default:
throw std::runtime_error(
}
// Type
model::DataType data_type = tensorTypeToDataType(tensor->type());
+ // Quantization
+ auto q_params = tensor->quantization();
+ float scale = 0.0;
+ long zero_point = 0;
+ if (q_params != nullptr)
+ {
+ if (q_params->scale())
+ {
+ if (q_params->scale()->size() != 1)
+ {
+ throw std::runtime_error("Only 1 scale for a tensor is supported.");
+ }
+ scale = q_params->scale()->Get(0);
+ }
+
+ if (q_params->zero_point())
+ {
+ if (q_params->zero_point()->size() != 1)
+ {
+ throw std::runtime_error("Only 1 zero_point value for a tensor is supported.");
+ }
+ zero_point = q_params->zero_point()->Get(0);
+ // zero_point is long while TypeInfo.zero_point is defined as int32_t.
+ assert(zero_point >= std::numeric_limits<int32_t>::min());
+ assert(zero_point <= std::numeric_limits<int32_t>::max());
+ }
+ auto details = q_params->details_as_CustomQuantization();
+ if (details != nullptr)
+ throw std::runtime_error("Custom Quantization is not supported");
+ }
// Create TypeInfo
- model::TypeInfo type_info(data_type);
+ model::TypeInfo type_info(data_type, scale, zero_point);
// Create operand
const auto operand_index = _graph.addOperand(shape, type_info);
// Name unused
// auto name = tensor->name();
- // Quantization
- auto quantization = tensor->quantization();
- if (quantization != nullptr)
- {
- auto scale = quantization->scale();
- auto zero_point = quantization->zero_point();
- if (scale != nullptr || zero_point != nullptr)
- throw std::runtime_error("Quantization is not supported!");
-
- auto details = quantization->details_as_CustomQuantization();
- if (details != nullptr)
- throw std::runtime_error("Custom Quantization is not supported");
- }
// Variable
if (tensor->is_variable())
throw std::runtime_error("Variable tensor not supported!");