Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tests / tools / onert_run / src / onert_run.cc
index 5acb2bb..0bc64bb 100644 (file)
@@ -23,6 +23,7 @@
 #include "nnfw.h"
 #include "nnfw_util.h"
 #include "nnfw_internal.h"
+#include "nnfw_experimental.h"
 #include "randomgen.h"
 #include "rawformatter.h"
 #ifdef RUY_PROFILER
@@ -48,6 +49,33 @@ void overwriteShapeMap(onert_run::TensorShapeMap &shape_map,
     shape_map[i] = shapes[i];
 }
 
+std::string genQuantizedModelPathFromModelPath(const std::string &model_path, bool is_q16)
+{
+  auto const extension_pos = model_path.find(".circle");
+  if (extension_pos == std::string::npos)
+  {
+    std::cerr << "Input model isn't .circle." << std::endl;
+    exit(-1);
+  }
+  auto const qstring = std::string("_quantized_") + (is_q16 ? "q16" : "q8");
+  return model_path.substr(0, extension_pos) + qstring + ".circle";
+}
+
+std::string genQuantizedModelPathFromPackagePath(const std::string &package_path, bool is_q16)
+{
+  auto package_path_without_slash = package_path;
+  if (package_path_without_slash.back() == '/')
+    package_path_without_slash.pop_back();
+  auto package_name_pos = package_path_without_slash.find_last_of('/');
+  if (package_name_pos == std::string::npos)
+    package_name_pos = 0;
+  else
+    package_name_pos++;
+  auto package_name = package_path_without_slash.substr(package_name_pos);
+  auto const qstring = std::string("_quantized_") + (is_q16 ? "q16" : "q8");
+  return package_path_without_slash + "/" + package_name + qstring + ".circle";
+}
+
 int main(const int argc, char **argv)
 {
   using namespace onert_run;
@@ -85,6 +113,37 @@ int main(const int argc, char **argv)
         NNPR_ENSURE_STATUS(nnfw_load_model_from_file(session, args.getPackageFilename().c_str()));
     });
 
+    // Quantize model
+    auto quantize = args.getQuantize();
+    if (!quantize.empty())
+    {
+      NNFW_QUANTIZE_TYPE quantize_type = NNFW_QUANTIZE_TYPE_NOT_SET;
+      if (quantize == "int8")
+        quantize_type = NNFW_QUANTIZE_TYPE_U8_ASYM;
+      if (quantize == "int16")
+        quantize_type = NNFW_QUANTIZE_TYPE_I16_SYM;
+      NNPR_ENSURE_STATUS(nnfw_set_quantization_type(session, quantize_type));
+
+      if (args.getQuantizedModelPath() != "")
+        NNPR_ENSURE_STATUS(
+          nnfw_set_quantized_model_path(session, args.getQuantizedModelPath().c_str()));
+      else
+      {
+        if (args.useSingleModel())
+          NNPR_ENSURE_STATUS(nnfw_set_quantized_model_path(
+            session,
+            genQuantizedModelPathFromModelPath(args.getModelFilename(), quantize == "int16")
+              .c_str()));
+        else
+          NNPR_ENSURE_STATUS(nnfw_set_quantized_model_path(
+            session,
+            genQuantizedModelPathFromPackagePath(args.getPackageFilename(), quantize == "int16")
+              .c_str()));
+      }
+
+      NNPR_ENSURE_STATUS(nnfw_quantize(session));
+    }
+
     char *available_backends = std::getenv("BACKENDS");
     if (available_backends)
       NNPR_ENSURE_STATUS(nnfw_set_available_backends(session, available_backends));