2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include <luci/CircleExporter.h>
#include <luci/CircleFileExpContract.h>

#include <cerrno>
#include <cstring>
26 using namespace mpqsolver::core;
// Top-level keys of the JSON document written by Dumper::dump_MPQ_configuration.
const std::string default_dtype_key{"default_quantization_dtype"};
const std::string default_granularity_key{"default_granularity"};
const std::string layers_key{"layers"};
const std::string model_key{"model_path"};
// Keys of each per-layer entry stored in the array under `layers_key`.
const std::string layer_name_key{"name"};
const std::string layer_dtype_key{"dtype"};
const std::string layer_granularity_key{"granularity"};
41 Dumper::Dumper(const std::string &dir_path) : _dir_path(dir_path) {}
43 void Dumper::set_model_path(const std::string &model_path) { _model_path = model_path; }
45 void Dumper::dump_MPQ_configuration(const LayerParams &layers, const std::string &def_dtype,
46 const std::string &path) const
49 mpq_data[default_dtype_key] = def_dtype;
50 mpq_data[default_granularity_key] = "channel";
51 mpq_data[model_key] = _model_path;
53 Json::Value layers_data;
54 for (auto &layer : layers)
56 Json::Value layer_data;
57 layer_data[layer_name_key] = layer->name;
58 layer_data[layer_granularity_key] = layer->granularity;
59 layer_data[layer_dtype_key] = layer->dtype;
60 layers_data.append(layer_data);
62 mpq_data[layers_key] = layers_data;
64 Json::StreamWriterBuilder builder;
65 auto data = Json::writeString(builder, mpq_data);
67 write_data_to_file(path, data);
70 void Dumper::prepare_directory(const std::string &dir_path) const
73 if (stat(dir_path.c_str(), &sb) != 0 || !S_ISDIR(sb.st_mode))
75 if (mkdir(dir_path.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH) != 0)
77 throw std::runtime_error("Failed to create directory for dumping intermediate results");
// Dumps the intermediate MPQ configuration identified by `step` into
// `<_dir_path>/Configuration_<step>.mpq.json`.
// The dump directory is first ensured via prepare_directory(); JSON serialization is then
// delegated to the path-based overload of dump_MPQ_configuration.
// NOTE(review): `step` is formatted with std::to_string, so it is presumably an integral
// id; confirm its declared type in Dumper.h.
82 void Dumper::dump_MPQ_configuration(const LayerParams &layers, const std::string &def_dtype,
85 prepare_directory(_dir_path);
// The step id is part of the file name, so each intermediate configuration gets its own file.
86 std::string path = _dir_path + "/Configuration_" + std::to_string(step) + ".mpq.json";
87 dump_MPQ_configuration(layers, def_dtype, path);
90 void Dumper::dump_final_MPQ(const LayerParams &layers, const std::string &def_dtype) const
92 prepare_directory(_dir_path);
93 std::string path = _dir_path + "/FinalConfiguration" + ".mpq.json";
94 dump_MPQ_configuration(layers, def_dtype, path);
// Presumably writes `data` (for example, the JSON text produced by dump_MPQ_configuration)
// to the file at `path`.
// NOTE(review): confirm whether an existing file is truncated or appended to, and how
// open/write failures are reported.
97 void Dumper::write_data_to_file(const std::string &path, const std::string &data) const
105 void Dumper::save_circle(luci::Module *module, std::string &path) const
107 luci::CircleExporter exporter;
108 luci::CircleFileExpContract contract(module, path);
109 if (!exporter.invoke(&contract))
111 throw std::runtime_error("Failed to export circle model to " + path);
115 void Dumper::dump_quantized(luci::Module *module, uint32_t step) const
117 std::string path = _dir_path + "/quantized_" + std::to_string(step) + ".mpq.circle";
118 save_circle(module, path);
121 void Dumper::dump_error(float error, const std::string &tag, const std::string &path) const
124 file.open(path, std::ios_base::app);
125 file << tag << " " << error << std::endl;
129 void Dumper::prepare_for_error_dumping() const
131 prepare_directory(_dir_path);
132 std::string path = get_error_path();
134 file.open(path); // create empty
138 void Dumper::dump_Q8_error(float error) const
140 std::string path = get_error_path();
141 dump_error(error, "Q8", path);
144 void Dumper::dump_Q16_error(float error) const
146 std::string path = get_error_path();
147 dump_error(error, "Q16", path);
150 void Dumper::dump_MPQ_error(float error, uint32_t step) const
152 std::string path = get_error_path();
153 dump_error(error, std::to_string(step), path);
156 void Dumper::dump_MPQ_error(float error) const
158 std::string path = get_error_path();
159 dump_error(error, "FINAL", path);