Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / CircleMPQSolver.cpp
index 23e8fd4..12981be 100644 (file)
 #include <vconone/vconone.h>
 #include <luci/CircleExporter.h>
 #include <luci/CircleFileExpContract.h>
-#include <luci/Log.h>
 
 #include "bisection/BisectionSolver.h"
+#include <core/SolverOutput.h>
 
 #include <iostream>
 #include <iomanip>
-#include <chrono>
 
 void print_version(void)
 {
@@ -32,11 +31,23 @@ void print_version(void)
   std::cout << vconone::get_copyright() << std::endl;
 }
 
-int entry(int argc, char **argv)
+int handleAutoAlgorithm(arser::Arser &arser, mpqsolver::bisection::BisectionSolver &solver)
 {
-  LOGGER(l);
+  solver.algorithm(mpqsolver::bisection::BisectionSolver::Algorithm::Auto);
+  auto data_path = arser.get<std::string>("--visq_file");
+  if (data_path.empty())
+  {
+    std::cerr << "ERROR: please provide visq_file for auto mode" << std::endl;
+    return false;
+  }
+  solver.setVisqPath(data_path);
+  return true;
+}
 
+int entry(int argc, char **argv)
+{
   const std::string bisection_str = "--bisection";
+  const std::string save_intermediate_str = "--save_intermediate";
 
   arser::Arser arser("circle-mpqsolver provides light-weight methods for finding a high-quality "
                      "mixed-precision model within a reasonable time.");
@@ -74,6 +85,17 @@ int entry(int argc, char **argv)
 
   arser.add_argument("--output_model").required(true).help("Output quantized model");
 
+  arser.add_argument("--visq_file")
+    .type(arser::DataType::STR)
+    .default_value("")
+    .required(false)
+    .help("*.visq.json file with quantization errors");
+
+  arser.add_argument(save_intermediate_str)
+    .type(arser::DataType::STR)
+    .required(false)
+    .help("path to save intermediate results");
+
   try
   {
     arser.parse(argc, argv);
@@ -104,7 +126,12 @@ int entry(int argc, char **argv)
     std::cerr << "ERROR: quantization ratio must be in [0, 1]" << std::endl;
     return EXIT_FAILURE;
   }
-  auto start = std::chrono::high_resolution_clock::now();
+
+  SolverOutput::get() << ">> Searching mixed precision configuration \n"
+                      << "model:" << input_model_path << "\n"
+                      << "dataset: " << data_path << "\n"
+                      << "input dtype: " << input_dtype << "\n"
+                      << "output dtype: " << output_dtype << "\n";
 
   if (arser[bisection_str])
   {
@@ -116,14 +143,20 @@ int entry(int argc, char **argv)
       auto value = arser.get<std::string>(bisection_str);
       if (value == "auto")
       {
-        solver.algorithm(BisectionSolver::Algorithm::Auto);
+        SolverOutput::get() << "algorithm: bisection (auto)\n";
+        if (!handleAutoAlgorithm(arser, solver))
+        {
+          return EXIT_FAILURE;
+        }
       }
       else if (value == "true")
       {
+        SolverOutput::get() << "algorithm: bisection (Q16AtFront)";
         solver.algorithm(BisectionSolver::Algorithm::ForceQ16Front);
       }
       else if (value == "false")
       {
+        SolverOutput::get() << "algorithm: bisection (Q8AtFront)";
         solver.algorithm(BisectionSolver::Algorithm::ForceQ16Back);
       }
       else
@@ -134,6 +167,18 @@ int entry(int argc, char **argv)
       }
     }
 
+    if (arser[save_intermediate_str])
+    {
+      auto data_path = arser.get<std::string>(save_intermediate_str);
+      if (!data_path.empty())
+      {
+        solver.set_save_intermediate(data_path);
+      }
+    }
+
+    SolverOutput::get() << "qerror metric: MAE\n"
+                        << "target qerror ratio: " << qerror_ratio << "\n";
+
     auto optimized = solver.run(input_model_path);
     if (optimized == nullptr)
     {
@@ -143,6 +188,7 @@ int entry(int argc, char **argv)
 
     // save optimized
     {
+      SolverOutput::get() << "Saving output model to " << output_model_path << "\n";
       luci::CircleExporter exporter;
       luci::CircleFileExpContract contract(optimized.get(), output_model_path);
       if (!exporter.invoke(&contract))
@@ -159,10 +205,5 @@ int entry(int argc, char **argv)
     return EXIT_FAILURE;
   }
 
-  auto duration = std::chrono::duration_cast<std::chrono::seconds>(
-    std::chrono::high_resolution_clock::now() - start);
-  VERBOSE(l, 0) << "Elapsed Time: " << std::setprecision(5) << duration.count() / 60.f
-                << " minutes." << std::endl;
-
   return EXIT_SUCCESS;
 }