Imported Upstream version 1.19.0
[platform/core/ml/nnfw.git] compiler/luci/pass/src/CircleOptimizer.cpp
index 5d0c926..75f04b3 100644
@@ -468,12 +468,20 @@ void CircleOptimizer::quantize(loco::Graph *g) const
     static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
     static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
     static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
+    static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
+    static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
 
     auto input_model_dtype =
       _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
     auto output_model_dtype =
       _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
     auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+    auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
+    if (input_type.empty())
+      input_type = output_model_dtype;
+    auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
+    if (output_type.empty())
+      output_type = output_model_dtype;
 
     if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
       throw std::runtime_error("Unsupported input type. List of supported input types: " +
@@ -487,13 +495,21 @@ void CircleOptimizer::quantize(loco::Graph *g) const
       throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
                                to_string(qwmm_supported_granularity));
 
+    if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
+      throw std::runtime_error("Unsupported input type. List of supported input types: " +
+                               to_string(qwmm_supported_input_type));
+
+    if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
+      throw std::runtime_error("Unsupported output type. List of supported output types: " +
+                               to_string(qwmm_supported_output_type));
+
     if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
         str_to_dtype(output_model_dtype) != loco::DataType::U8)
       throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
 
-    luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_model_dtype),
-                                           str_to_dtype(output_model_dtype),
-                                           str_to_granularity(granularity));
+    luci::QuantizeWithMinMaxPass quantizer(
+      str_to_dtype(input_model_dtype), str_to_dtype(output_model_dtype),
+      str_to_granularity(granularity), str_to_dtype(input_type), str_to_dtype(output_type));
     quantizer.run(g);
 
     // Post-quantization optimizations
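For context, the new `Quantize_input_type` and `Quantize_output_type` parameters are read through the same `Options::param()` interface as the existing quantization settings, and each falls back to `Quantize_output_model_dtype` when left empty. Below is a minimal sketch of how a driver might configure the pass after this change; it assumes the `CircleOptimizer::Options` API exposes `enable()`/`param()` as used in the diff above, and the parameter values shown are illustrative only.

```cpp
#include <luci/CircleOptimizer.h>

// Sketch: quantize a graph with per-channel int16 quantization while keeping
// uint8 graph inputs/outputs. Assumes `graph` was loaded elsewhere and that
// the Options interface matches the param()/enable() usage seen in the diff.
void quantize_example(loco::Graph *graph)
{
  luci::CircleOptimizer optimizer;
  auto options = optimizer.options();

  using Options = luci::CircleOptimizer::Options;

  options->enable(Options::Algorithm::QuantizeWithMinMax);
  options->param(Options::AlgorithmParameters::Quantize_input_model_dtype, "float32");
  options->param(Options::AlgorithmParameters::Quantize_output_model_dtype, "int16");
  options->param(Options::AlgorithmParameters::Quantize_granularity, "channel");

  // New in this change: the dtype of the graph's inputs/outputs may differ
  // from the quantized model dtype. If either parameter is omitted, it
  // defaults to Quantize_output_model_dtype (see the empty() checks above).
  options->param(Options::AlgorithmParameters::Quantize_input_type, "uint8");
  options->param(Options::AlgorithmParameters::Quantize_output_type, "uint8");

  optimizer.quantize(graph);
}
```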