/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "Evaluator.h"
19 #include <luci_interpreter/Interpreter.h>
21 #include <dio_hdf5/HDF5Importer.h>
23 using namespace mpqsolver::bisection;
25 using Shape = std::vector<loco::Dimension>;
32 template <typename NodeT> size_t get_tensor_size(const NodeT *node)
34 uint32_t tensor_size = loco::size(node->dtype());
35 for (uint32_t i = 0; i < node->rank(); ++i)
36 tensor_size *= node->dim(i).value();
40 WholeOutput compute_outputs(const luci::Module *module, const std::string &h5file)
42 dio::hdf5::HDF5Importer importer{h5file};
43 importer.importGroup("value");
45 bool is_raw_data = importer.isRawData();
47 const auto num_records = importer.numData();
49 throw std::runtime_error("The input data file does not contain any record.");
50 const auto input_nodes = loco::input_nodes(module->graph());
51 const auto num_inputs = input_nodes.size();
53 WholeOutput dataset_output;
55 // Create interpreter.
56 luci_interpreter::Interpreter interpreter(module);
57 for (int32_t record_idx = 0; record_idx < num_records; record_idx++)
59 if (num_inputs != static_cast<uint32_t>(importer.numInputs(record_idx)))
60 throw std::runtime_error("Wrong number of inputs.");
61 for (uint32_t input_idx = 0; input_idx < num_inputs; input_idx++)
63 const auto *input_node = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
64 assert(input_node->index() == input_idx);
66 std::vector<char> input_data(get_tensor_size(input_node));
72 importer.readTensor(record_idx, input_idx, &dtype, &shape, input_data.data());
76 // Skip type/shape check for raw data
77 importer.readTensor(record_idx, input_idx, input_data.data());
80 interpreter.writeInputTensor(input_node, input_data.data(), input_data.size());
83 interpreter.interpret();
88 const auto output_nodes = loco::output_nodes(module->graph());
89 for (size_t i = 0; i < module->graph()->outputs()->size(); i++)
91 const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
92 Buffer output_data(get_tensor_size(output_node));
93 interpreter.readOutputTensor(output_node, output_data.data(), output_data.size());
95 nn_output.push_back(output_data);
97 dataset_output.push_back(nn_output);
100 return dataset_output;
105 DatasetEvaluator::DatasetEvaluator(const luci::Module *ref_module, const std::string &h5file,
106 const ErrorMetric &metric)
107 : _ref_module(ref_module), _h5file(h5file), _metric(&metric)
109 _ref_output = compute_outputs(_ref_module, _h5file);
112 void DatasetEvaluator::validate(const luci::Module *trgt_fq_module) const
114 const auto output_nodes = loco::output_nodes(trgt_fq_module->graph());
115 for (size_t out_index = 0; out_index < output_nodes.size(); ++out_index)
117 const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[out_index]);
118 loco::DataType out_dtype = output_node->dtype();
119 if (out_dtype != loco::DataType::FLOAT32)
120 throw std::runtime_error("Unsupported output dtype " + output_node->name());
124 float DatasetEvaluator::evaluate(const luci::Module *trgt_fq_module) const
126 if (trgt_fq_module == nullptr)
127 throw std::runtime_error("Invalid target module");
129 if (_metric == nullptr)
130 throw std::runtime_error("Invalid metric");
132 validate(trgt_fq_module);
134 const WholeOutput &cur_output = compute_outputs(trgt_fq_module, _h5file);
135 float error = _metric->compute(_ref_output, cur_output);