2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "Evaluator.h"
19 #include <luci_interpreter/Interpreter.h>
21 #include <dio_hdf5/HDF5Importer.h>
23 using namespace mpqsolver::core;
25 using Shape = std::vector<loco::Dimension>;
32 template <typename NodeT> size_t get_tensor_size(const NodeT *node)
34 uint32_t tensor_size = loco::size(node->dtype());
35 for (uint32_t i = 0; i < node->rank(); ++i)
36 tensor_size *= node->dim(i).value();
// Feeds every record of the HDF5 dataset `h5file` through an interpreter built for `module`
// and returns the collected outputs, one entry per record.
//
// For each record, every graph input is read from the importer into a byte buffer sized by
// get_tensor_size() and written to the interpreter. After interpret(), every graph output is
// read back into a Buffer and appended to that record's output list.
//
// Throws std::runtime_error when a record's number of inputs differs from the number of
// graph inputs.
//
// NOTE(review): several statements of this function are not visible in this view: the
// condition guarding the "does not contain any record" throw, the declarations of `dtype`,
// `shape` and `nn_output`, and the branch that chooses between the two readTensor calls.
// Confirm against the complete file before relying on the details below.
40 WholeOutput compute_outputs(const luci::Module *module, const std::string &h5file)
42 dio::hdf5::HDF5Importer importer{h5file};
// Import the "value" group of the file before querying it.
43 importer.importGroup("value");
// Queried once up front; presumably selects which readTensor overload is used below
// (the selecting branch is not visible here).
45 bool is_raw_data = importer.isRawData();
47 const auto num_records = importer.numData();
// NOTE(review): the guard for this throw is not visible; presumably it fires only when
// num_records is zero.
49 throw std::runtime_error("The input data file does not contain any record.");
50 const auto input_nodes = loco::input_nodes(module->graph());
51 const auto num_inputs = input_nodes.size();
// Accumulates one entry per record; this is the value returned to the caller.
53 WholeOutput dataset_output;
55 // Create interpreter.
56 luci_interpreter::Interpreter interpreter(module);
57 for (int32_t record_idx = 0; record_idx < num_records; record_idx++)
// Each record must supply exactly as many tensors as the graph has inputs.
59 if (num_inputs != static_cast<uint32_t>(importer.numInputs(record_idx)))
60 throw std::runtime_error("Wrong number of inputs.");
61 for (uint32_t input_idx = 0; input_idx < num_inputs; input_idx++)
63 const auto *input_node = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
// Debug-only check that the position in input_nodes matches the node's own index.
64 assert(input_node->index() == input_idx);
// Byte buffer sized from the input node's data type and dimensions.
66 std::vector<char> input_data(get_tensor_size(input_node));
// This overload additionally receives &dtype and &shape.
// NOTE(review): their declarations and the remaining arguments of this call are not
// visible in this view.
72 importer.readTensor(record_idx, input_idx, &dtype, &shape, input_data.data(),
77 // Skip type/shape check for raw data
78 importer.readTensor(record_idx, input_idx, input_data.data(), input_data.size());
// Hand the bytes just read to the interpreter as this input's contents.
81 interpreter.writeInputTensor(input_node, input_data.data(), input_data.size());
// All inputs of the record are written; run the model.
84 interpreter.interpret();
// NOTE(review): output_nodes is fetched again for every record although `module` is not
// modified inside the loop; this looks like a candidate for hoisting above the record loop.
89 const auto output_nodes = loco::output_nodes(module->graph());
// The loop bound comes from graph()->outputs()->size() while the body indexes output_nodes;
// assumes both have the same number of entries (TODO confirm).
90 for (size_t i = 0; i < module->graph()->outputs()->size(); i++)
92 const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
// Buffer sized from the output node's data type and dimensions, then filled by the interpreter.
93 Buffer output_data(get_tensor_size(output_node));
94 interpreter.readOutputTensor(output_node, output_data.data(), output_data.size());
// Append this output's buffer to the current record's outputs (nn_output is declared on a
// line not visible here).
96 nn_output.push_back(output_data);
// One entry per record, in record order.
98 dataset_output.push_back(nn_output);
101 return dataset_output;
106 DatasetEvaluator::DatasetEvaluator(const luci::Module *ref_module, const std::string &h5file,
107 const ErrorMetric &metric)
108 : _ref_module(ref_module), _h5file(h5file), _metric(&metric)
110 _ref_output = compute_outputs(_ref_module, _h5file);
113 void DatasetEvaluator::validate(const luci::Module *trgt_fq_module) const
115 const auto output_nodes = loco::output_nodes(trgt_fq_module->graph());
116 for (size_t out_index = 0; out_index < output_nodes.size(); ++out_index)
118 const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[out_index]);
119 loco::DataType out_dtype = output_node->dtype();
120 if (out_dtype != loco::DataType::FLOAT32)
121 throw std::runtime_error("Unsupported output dtype " + output_node->name());
// Measures how far the outputs of `trgt_fq_module` are from the cached reference outputs.
//
// Throws std::runtime_error when `trgt_fq_module` is null, when no metric is stored, or
// (via validate()) when one of the module's outputs is not FLOAT32.
//
// NOTE(review): the end of this function is not visible in this view; presumably it
// returns `error`.
125 float DatasetEvaluator::evaluate(const luci::Module *trgt_fq_module) const
127 if (trgt_fq_module == nullptr)
128 throw std::runtime_error("Invalid target module");
130 if (_metric == nullptr)
131 throw std::runtime_error("Invalid metric");
// Reject modules with non-FLOAT32 outputs before spending time on inference.
133 validate(trgt_fq_module);
// Run the same dataset (_h5file) through the target module. The temporary returned by
// compute_outputs is bound to a const reference, which keeps it alive for the rest of
// the scope.
135 const WholeOutput &cur_output = compute_outputs(trgt_fq_module, _h5file);
// Compare the cached reference outputs with the freshly computed ones using the stored metric.
136 float error = _metric->compute(_ref_output, cur_output);