Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / circle-quantizer / src / CircleQuantizer.cpp
index 5e717d0..1a09a8a 100644 (file)
@@ -43,6 +43,7 @@ void print_exclusive_options(void)
   std::cout << "    --quantize_dequantize_weights" << std::endl;
   std::cout << "    --quantize_with_minmax" << std::endl;
   std::cout << "    --requantize" << std::endl;
+  std::cout << "    --force_quantparam" << std::endl;
 }
 
 void print_version(void)
@@ -63,6 +64,7 @@ int entry(int argc, char **argv)
   const std::string qdqw = "--quantize_dequantize_weights";
   const std::string qwmm = "--quantize_with_minmax";
   const std::string rq = "--requantize";
+  const std::string fq = "--force_quantparam";
 
   const std::string gpd = "--generate_profile_data";
 
@@ -105,6 +107,15 @@ int entry(int argc, char **argv)
           "Two arguments required: input_dtype(int8) "
           "output_dtype(uint8)");
 
+  arser.add_argument(fq)
+    .nargs(3)
+    .type(arser::DataType::STR_VEC)
+    .required(false)
+    .accumulated(true)
+    .help("Write quantization parameters to the specified tensor. "
+          "Three arguments required: tensor_name(string), "
+          "scale(float) zero_point(int)");
+
   arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
   arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
 
@@ -123,10 +134,11 @@ int entry(int argc, char **argv)
   }
 
   {
-    // only one of qdqw, qwmm, rq option can be used
+    // only one of qdqw, qwmm, rq, fq option can be used
     int32_t opt_used = arser[qdqw] ? 1 : 0;
     opt_used += arser[qwmm] ? 1 : 0;
     opt_used += arser[rq] ? 1 : 0;
+    opt_used += arser[fq] ? 1 : 0;
     if (opt_used != 1)
     {
       print_exclusive_options();
@@ -185,6 +197,34 @@ int entry(int argc, char **argv)
     options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
   }
 
+  if (arser[fq])
+  {
+    auto values = arser.get<std::vector<std::vector<std::string>>>(fq);
+
+    std::vector<std::string> tensors;
+    std::vector<std::string> scales;
+    std::vector<std::string> zero_points;
+
+    for (auto const value : values)
+    {
+      if (value.size() != 3)
+      {
+        std::cerr << arser;
+        return 255;
+      }
+
+      tensors.push_back(value[0]);
+      scales.push_back(value[1]);
+      zero_points.push_back(value[2]);
+    }
+
+    options->enable(Algorithms::ForceQuantParam);
+
+    options->params(AlgorithmParameters::Quantize_tensor_names, tensors);
+    options->params(AlgorithmParameters::Quantize_scales, scales);
+    options->params(AlgorithmParameters::Quantize_zero_points, zero_points);
+  }
+
   std::string input_path = arser.get<std::string>("input");
   std::string output_path = arser.get<std::string>("output");