Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / CircleQuantizer.cpp
index 3ffa118..9039a83 100644 (file)
@@ -26,6 +26,7 @@
 #include "luci/Pass/QuantizePreCheckerPass.h"
 #include "luci/Pass/QuantizeWithMinMaxPass.h"
 #include "luci/Pass/QuantizeDequantizeWeightsPass.h"
+#include "luci/Pass/QuantizeWeightsPass.h"
 
 #include "luci/Pass/CircleShapeInferencePass.h"
 #include "luci/Pass/CircleTypeInferencePass.h"
@@ -439,14 +440,14 @@ void CircleQuantizer::quantize(loco::Graph *g) const
       throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
                                to_string(qwmm_supported_granularity));
 
-    for (auto dtype : input_type_vec)
+    for (const auto &dtype : input_type_vec)
     {
       if (!in_array(to_lower_case(dtype), qwmm_supported_input_type))
         throw std::runtime_error("Unsupported input type. List of supported input types: " +
                                  to_string(qwmm_supported_input_type));
     }
 
-    for (auto dtype : output_type_vec)
+    for (const auto &dtype : output_type_vec)
     {
       if (!in_array(to_lower_case(dtype), qwmm_supported_output_type))
         throw std::runtime_error("Unsupported output type. List of supported output types: " +
@@ -536,6 +537,40 @@ void CircleQuantizer::quantize(loco::Graph *g) const
     verifier.verify(g);
   }
 
+  if (_options->query(Options::Algorithm::QuantizeWeights))
+  {
+    static const std::vector<std::string> qw_supported_input_model_dtype{"float32"};
+    static const std::vector<std::string> qw_supported_output_model_dtype{"int8", "int16"};
+    static const std::vector<std::string> qw_supported_granularity{"channel"};
+
+    auto input_model_dtype =
+      _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+    auto output_model_dtype =
+      _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+    auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+
+    if (!in_array(to_lower_case(input_model_dtype), qw_supported_input_model_dtype))
+      throw std::runtime_error("Unsupported input type. List of supported input type: " +
+                               to_string(qw_supported_input_model_dtype));
+
+    if (!in_array(to_lower_case(output_model_dtype), qw_supported_output_model_dtype))
+      throw std::runtime_error("Unsupported output type. List of supported output type: " +
+                               to_string(qw_supported_output_model_dtype));
+
+    if (!in_array(to_lower_case(granularity), qw_supported_granularity))
+      throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
+                               to_string(qw_supported_granularity));
+    auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>();
+    {
+      ctx->input_model_dtype = str_to_dtype(input_model_dtype);
+      ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+      ctx->granularity = str_to_granularity(granularity);
+    }
+    luci::QuantizeWeightsPass weights_quantizer(std::move(ctx));
+
+    weights_quantizer.run(g);
+  }
+
   // Requantize
   if (_options->query(Options::Algorithm::Requantize))
   {