#include <vconone/vconone.h>
#include <luci/CircleExporter.h>
#include <luci/CircleFileExpContract.h>
-#include <luci/Log.h>
#include "bisection/BisectionSolver.h"
+#include <core/SolverOutput.h>
#include <iostream>
#include <iomanip>
-#include <chrono>
void print_version(void)
{
std::cout << vconone::get_copyright() << std::endl;
}
-int entry(int argc, char **argv)
+int handleAutoAlgorithm(arser::Arser &arser, mpqsolver::bisection::BisectionSolver &solver)
{
- LOGGER(l);
+ solver.algorithm(mpqsolver::bisection::BisectionSolver::Algorithm::Auto);
+ auto data_path = arser.get<std::string>("--visq_file");
+ if (data_path.empty())
+ {
+ std::cerr << "ERROR: please provide visq_file for auto mode" << std::endl;
+ return false;
+ }
+ solver.setVisqPath(data_path);
+ return true;
+}
+int entry(int argc, char **argv)
+{
const std::string bisection_str = "--bisection";
+ const std::string save_intermediate_str = "--save_intermediate";
arser::Arser arser("circle-mpqsolver provides light-weight methods for finding a high-quality "
"mixed-precision model within a reasonable time.");
arser.add_argument("--output_model").required(true).help("Output quantized model");
+ arser.add_argument("--visq_file")
+ .type(arser::DataType::STR)
+ .default_value("")
+ .required(false)
+ .help("*.visq.json file with quantization errors");
+
+ arser.add_argument(save_intermediate_str)
+ .type(arser::DataType::STR)
+ .required(false)
+ .help("path to save intermediate results");
+
try
{
arser.parse(argc, argv);
std::cerr << "ERROR: quantization ratio must be in [0, 1]" << std::endl;
return EXIT_FAILURE;
}
- auto start = std::chrono::high_resolution_clock::now();
+
+ SolverOutput::get() << ">> Searching mixed precision configuration \n"
+ << "model:" << input_model_path << "\n"
+ << "dataset: " << data_path << "\n"
+ << "input dtype: " << input_dtype << "\n"
+ << "output dtype: " << output_dtype << "\n";
if (arser[bisection_str])
{
auto value = arser.get<std::string>(bisection_str);
if (value == "auto")
{
- solver.algorithm(BisectionSolver::Algorithm::Auto);
+ SolverOutput::get() << "algorithm: bisection (auto)\n";
+ if (!handleAutoAlgorithm(arser, solver))
+ {
+ return EXIT_FAILURE;
+ }
}
else if (value == "true")
{
+ SolverOutput::get() << "algorithm: bisection (Q16AtFront)";
solver.algorithm(BisectionSolver::Algorithm::ForceQ16Front);
}
else if (value == "false")
{
+ SolverOutput::get() << "algorithm: bisection (Q8AtFront)";
solver.algorithm(BisectionSolver::Algorithm::ForceQ16Back);
}
else
}
}
+ if (arser[save_intermediate_str])
+ {
+ auto data_path = arser.get<std::string>(save_intermediate_str);
+ if (!data_path.empty())
+ {
+ solver.set_save_intermediate(data_path);
+ }
+ }
+
+ SolverOutput::get() << "qerror metric: MAE\n"
+ << "target qerror ratio: " << qerror_ratio << "\n";
+
auto optimized = solver.run(input_model_path);
if (optimized == nullptr)
{
// save optimized
{
+ SolverOutput::get() << "Saving output model to " << output_model_path << "\n";
luci::CircleExporter exporter;
luci::CircleFileExpContract contract(optimized.get(), output_model_path);
if (!exporter.invoke(&contract))
return EXIT_FAILURE;
}
- auto duration = std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::high_resolution_clock::now() - start);
- VERBOSE(l, 0) << "Elapsed Time: " << std::setprecision(5) << duration.count() / 60.f
- << " minutes." << std::endl;
-
return EXIT_SUCCESS;
}