2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "CircleEvalDiff.h"
18 #include "InputDataLoader.h"
19 #include "MetricPrinter.h"
22 #include <foder/FileLoader.h>
23 #include <luci/Importer.h>
30 bool same_shape(const luci::CircleNode *a, const luci::CircleNode *b)
32 if (a->rank() != b->rank())
35 for (uint32_t i = 0; i < a->rank(); i++)
37 if (not(a->dim(i) == b->dim(i)))
44 bool same_dtype(const luci::CircleNode *a, const luci::CircleNode *b)
46 return a->dtype() == b->dtype();
49 std::unique_ptr<luci::Module> import(const std::string &model_path)
51 // Load model from the file
52 foder::FileLoader loader{model_path};
53 std::vector<char> model_data = loader.load();
56 flatbuffers::Verifier verifier{reinterpret_cast<const uint8_t *>(model_data.data()),
58 if (not circle::VerifyModelBuffer(verifier))
60 throw std::runtime_error("Failed to verify circle '" + model_path + "'");
63 auto circle_model = circle::GetModel(model_data.data());
66 throw std::runtime_error("Failed to load '" + model_path + "'");
68 auto module = luci::Importer().importModule(circle_model);
71 throw std::runtime_error("Failed to load '" + model_path + "'");
76 const std::vector<loco::Node *> inputs_of(const luci::Module *module)
78 return loco::input_nodes(module->graph());
81 const std::vector<loco::Node *> outputs_of(const luci::Module *module)
83 return loco::output_nodes(module->graph());
// Writes `data_size` bytes starting at `data` to `filename` in binary mode,
// truncating any existing file.
// Throws std::runtime_error if the file cannot be opened, or if writing,
// flushing or closing the stream fails.
void writeDataToFile(const std::string &filename, const char *data, size_t data_size)
{
  std::ofstream fs(filename, std::ofstream::binary);
  if (fs.fail())
    throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
  if (fs.write(data, static_cast<std::streamsize>(data_size)).fail())
    throw std::runtime_error("Failed to write data to file \"" + filename + "\".\n");

  // Buffered bytes only reach the file when the stream is flushed/closed, and a
  // failure inside the destructor would be silently ignored (e.g. disk full),
  // so close explicitly and check the result.
  fs.close();
  if (fs.fail())
    throw std::runtime_error("Failed to write data to file \"" + filename + "\".\n");
}
97 void checkOutputs(const luci::Module *first, const luci::Module *second)
99 const auto first_output = outputs_of(first);
100 const auto second_output = outputs_of(second);
102 if (first_output.size() != second_output.size())
103 throw std::runtime_error("Models have different output counts");
105 for (uint32_t i = 0; i < first_output.size(); i++)
107 const auto first_node = loco::must_cast<luci::CircleNode *>(first_output[i]);
108 const auto second_node = loco::must_cast<luci::CircleNode *>(second_output[i]);
110 if (not same_shape(first_node, second_node))
111 throw std::runtime_error("Output shape mismatch (" + first_node->name() + ", " +
112 second_node->name() + ")");
114 if (not same_dtype(first_node, second_node))
115 throw std::runtime_error("Output dtype mismatch (" + first_node->name() + ", " +
116 second_node->name() + ")");
122 namespace circle_eval_diff
125 std::vector<std::shared_ptr<Tensor>> interpret(const luci::Module *module,
126 const InputDataLoader::Data &data)
128 auto interpreter = std::make_unique<luci_interpreter::Interpreter>(module);
130 auto input_nodes = ::inputs_of(module);
131 auto output_nodes = ::outputs_of(module);
133 for (uint32_t input_idx = 0; input_idx < data.size(); input_idx++)
135 auto input_node = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
136 assert(input_node->index() == input_idx);
138 auto input_data = data.at(input_idx);
139 interpreter->writeInputTensor(input_node, input_data.buffer(), input_data.byte_size());
142 interpreter->interpret();
144 std::vector<std::shared_ptr<Tensor>> outputs;
145 for (uint32_t output_idx = 0; output_idx < output_nodes.size(); output_idx++)
147 auto output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[output_idx]);
148 assert(output_node->index() == output_idx);
150 auto tensor = createEmptyTensor(output_node);
151 interpreter->readOutputTensor(output_node, tensor->buffer(), tensor->byte_size());
152 outputs.emplace_back(tensor);
158 CircleEvalDiff::CircleEvalDiff(std::unique_ptr<Context> &&ctx) : _ctx(std::move(ctx))
163 CircleEvalDiff::~CircleEvalDiff() = default;
165 void CircleEvalDiff::init()
167 _first_module = import(_ctx->first_model_path);
168 _second_module = import(_ctx->second_model_path);
170 // Check modules have the same output signature (dtype/shape)
171 // Exception will be thrown if they have different signature
172 checkOutputs(_first_module.get(), _second_module.get());
175 std::unique_ptr<MetricPrinter> metric;
176 for (auto metric : _ctx->metric)
182 _metrics.emplace_back(std::make_unique<MAEPrinter>());
187 _metrics.emplace_back(std::make_unique<MAPEPrinter>());
192 _metrics.emplace_back(std::make_unique<MPEIRPrinter>());
197 _metrics.emplace_back(std::make_unique<TopKMatchPrinter>(1));
202 _metrics.emplace_back(std::make_unique<TopKMatchPrinter>(5));
207 _metrics.emplace_back(std::make_unique<MSEPrinter>());
211 throw std::runtime_error("Unsupported metric.");
213 _metrics.back()->init(_first_module.get(), _second_module.get());
217 void CircleEvalDiff::evalDiff(void) const
219 auto first_input_loader = circle_eval_diff::makeDataLoader(
220 _ctx->first_input_data_path, _ctx->input_format, ::inputs_of(_first_module.get()));
221 auto second_input_loader = circle_eval_diff::makeDataLoader(
222 _ctx->second_input_data_path, _ctx->input_format, ::inputs_of(_second_module.get()));
224 for (uint32_t data_idx = 0; data_idx < first_input_loader->size(); data_idx++)
226 std::cout << "Evaluating " << data_idx << "'th data" << std::endl;
228 auto first_data = first_input_loader->get(data_idx);
229 auto second_data = second_input_loader->get(data_idx);
231 auto first_output = interpret(_first_module.get(), first_data);
232 auto second_output = interpret(_second_module.get(), second_data);
234 for (auto &metric : _metrics)
236 metric->accumulate(first_output, second_output);
239 if (_ctx.get()->output_prefix.empty())
242 for (uint32_t i = 0; i < first_output.size(); i++)
244 auto out = first_output[i];
245 writeDataToFile(_ctx.get()->output_prefix + "." + std::to_string(data_idx) + ".first.output" +
247 (char *)(out->buffer()), out->byte_size());
249 for (uint32_t i = 0; i < second_output.size(); i++)
251 auto out = second_output[i];
252 writeDataToFile(_ctx.get()->output_prefix + "." + std::to_string(data_idx) +
253 ".second.output" + std::to_string(i),
254 (char *)(out->buffer()), out->byte_size());
258 for (auto &metric : _metrics)
260 std::cout << metric.get() << std::endl;
264 } // namespace circle_eval_diff