// Allowed option values for the quantize-with-min-max step. Each list is
// checked with in_array() after the option string has been lower-cased.
// Input model dtype: only "float32" is accepted.
static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
// Output model dtype (the dtype the model is quantized to).
static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
// Quantization granularity.
static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
// Allowed values for the separate input_type / output_type options.
+ static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
+ static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
// Read the quantization parameters from the option store.
auto input_model_dtype =
_options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
auto output_model_dtype =
_options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
// input_type / output_type are optional: when the option comes back empty,
// fall back to output_model_dtype.
+ auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
+ if (input_type.empty())
+ input_type = output_model_dtype;
+ auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
+ if (output_type.empty())
+ output_type = output_model_dtype;
if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
throw std::runtime_error("Unsupported input type. List of supported input types: " +
throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
to_string(qwmm_supported_granularity));
// Reject an input_type that is not in the supported list (compared after
// lower-casing).
// NOTE(review): this message prefix is identical to the one thrown for the
// input_model_dtype check, so the two failures can only be told apart by the
// list that follows the prefix - consider a distinct wording.
+ if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(qwmm_supported_input_type));
+
// Same check for output_type.
+ if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(qwmm_supported_output_type));
+
// Layer-wise granularity is only allowed together with a U8 output model
// dtype; any other combination is rejected here.
if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
str_to_dtype(output_model_dtype) != loco::DataType::U8)
throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
// Construct the quantization pass from the parsed options and run it on `g`.
// The five-argument form additionally passes the converted input_type and
// output_type after the model dtypes and the granularity.
- luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_model_dtype),
- str_to_dtype(output_model_dtype),
- str_to_granularity(granularity));
+ luci::QuantizeWithMinMaxPass quantizer(
+ str_to_dtype(input_model_dtype), str_to_dtype(output_model_dtype),
+ str_to_granularity(granularity), str_to_dtype(input_type), str_to_dtype(output_type));
quantizer.run(g);
// Post-quantization optimizations