#include "luci/Pass/QuantizePreCheckerPass.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
+#include "luci/Pass/QuantizeWeightsPass.h"
#include "luci/Pass/CircleShapeInferencePass.h"
#include "luci/Pass/CircleTypeInferencePass.h"
throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
to_string(qwmm_supported_granularity));
- for (auto dtype : input_type_vec)
+ for (const auto &dtype : input_type_vec)
{
if (!in_array(to_lower_case(dtype), qwmm_supported_input_type))
throw std::runtime_error("Unsupported input type. List of supported input types: " +
to_string(qwmm_supported_input_type));
}
- for (auto dtype : output_type_vec)
+ for (const auto &dtype : output_type_vec)
{
if (!in_array(to_lower_case(dtype), qwmm_supported_output_type))
throw std::runtime_error("Unsupported output type. List of supported output types: " +
verifier.verify(g);
}
+ if (_options->query(Options::Algorithm::QuantizeWeights)) // Weights quantization branch: taken only when this algorithm is enabled in the options
+ {
+ static const std::vector<std::string> qw_supported_input_model_dtype{"float32"};
+ static const std::vector<std::string> qw_supported_output_model_dtype{"int8", "int16"};
+ static const std::vector<std::string> qw_supported_granularity{"channel"};
+ // Fetch the user-supplied dtypes and granularity from the option parameters.
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+ auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+ // Each parameter is lower-cased and checked against its supported list; a mismatch throws std::runtime_error.
+ if (!in_array(to_lower_case(input_model_dtype), qw_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input type: " +
+ to_string(qw_supported_input_model_dtype));
+ // Same check for the output model dtype.
+ if (!in_array(to_lower_case(output_model_dtype), qw_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output type: " +
+ to_string(qw_supported_output_model_dtype));
+ // Same check for the granularity.
+ if (!in_array(to_lower_case(granularity), qw_supported_granularity))
+ throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
+ to_string(qw_supported_granularity));
+ auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>(); // settings handed to the pass below
+ {
+ ctx->input_model_dtype = str_to_dtype(input_model_dtype); // NOTE(review): validation lower-cases, but the original string is converted here - confirm the str_to_* helpers accept mixed case
+ ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ ctx->granularity = str_to_granularity(granularity);
+ }
+ luci::QuantizeWeightsPass weights_quantizer(std::move(ctx)); // ownership of ctx moves into the pass
+ // Run the configured pass on g.
+ weights_quantizer.run(g);
+ }
+
// Requantize
if (_options->query(Options::Algorithm::Requantize))
{