94d46a39c0348db5ed44840660bad4722e49aadc
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / bisection / Evaluator.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Evaluator.h"
18
19 #include <luci_interpreter/Interpreter.h>
20
21 #include <dio_hdf5/HDF5Importer.h>
22
23 using namespace mpqsolver::bisection;
24
25 using Shape = std::vector<loco::Dimension>;
26
27 namespace
28 {
29
30 using namespace luci;
31
32 template <typename NodeT> size_t get_tensor_size(const NodeT *node)
33 {
34   uint32_t tensor_size = loco::size(node->dtype());
35   for (uint32_t i = 0; i < node->rank(); ++i)
36     tensor_size *= node->dim(i).value();
37   return tensor_size;
38 }
39
40 WholeOutput compute_outputs(const luci::Module *module, const std::string &h5file)
41 {
42   dio::hdf5::HDF5Importer importer{h5file};
43   importer.importGroup("value");
44
45   bool is_raw_data = importer.isRawData();
46
47   const auto num_records = importer.numData();
48   if (num_records == 0)
49     throw std::runtime_error("The input data file does not contain any record.");
50   const auto input_nodes = loco::input_nodes(module->graph());
51   const auto num_inputs = input_nodes.size();
52
53   WholeOutput dataset_output;
54
55   // Create interpreter.
56   luci_interpreter::Interpreter interpreter(module);
57   for (int32_t record_idx = 0; record_idx < num_records; record_idx++)
58   {
59     if (num_inputs != static_cast<uint32_t>(importer.numInputs(record_idx)))
60       throw std::runtime_error("Wrong number of inputs.");
61     for (uint32_t input_idx = 0; input_idx < num_inputs; input_idx++)
62     {
63       const auto *input_node = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
64       assert(input_node->index() == input_idx);
65
66       std::vector<char> input_data(get_tensor_size(input_node));
67
68       if (!is_raw_data)
69       {
70         loco::DataType dtype;
71         Shape shape;
72         importer.readTensor(record_idx, input_idx, &dtype, &shape, input_data.data());
73       }
74       else
75       {
76         // Skip type/shape check for raw data
77         importer.readTensor(record_idx, input_idx, input_data.data());
78       }
79
80       interpreter.writeInputTensor(input_node, input_data.data(), input_data.size());
81     }
82
83     interpreter.interpret();
84
85     Output nn_output;
86
87     // Get output.
88     const auto output_nodes = loco::output_nodes(module->graph());
89     for (size_t i = 0; i < module->graph()->outputs()->size(); i++)
90     {
91       const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
92       Buffer output_data(get_tensor_size(output_node));
93       interpreter.readOutputTensor(output_node, output_data.data(), output_data.size());
94       // output
95       nn_output.push_back(output_data);
96     }
97     dataset_output.push_back(nn_output);
98   }
99
100   return dataset_output;
101 }
102
103 } // namespace
104
105 DatasetEvaluator::DatasetEvaluator(const luci::Module *ref_module, const std::string &h5file,
106                                    const ErrorMetric &metric)
107   : _ref_module(ref_module), _h5file(h5file), _metric(&metric)
108 {
109   _ref_output = compute_outputs(_ref_module, _h5file);
110 }
111
112 void DatasetEvaluator::validate(const luci::Module *trgt_fq_module) const
113 {
114   const auto output_nodes = loco::output_nodes(trgt_fq_module->graph());
115   for (size_t out_index = 0; out_index < output_nodes.size(); ++out_index)
116   {
117     const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[out_index]);
118     loco::DataType out_dtype = output_node->dtype();
119     if (out_dtype != loco::DataType::FLOAT32)
120       throw std::runtime_error("Unsupported output dtype " + output_node->name());
121   }
122 }
123
124 float DatasetEvaluator::evaluate(const luci::Module *trgt_fq_module) const
125 {
126   if (trgt_fq_module == nullptr)
127     throw std::runtime_error("Invalid target module");
128
129   if (_metric == nullptr)
130     throw std::runtime_error("Invalid metric");
131
132   validate(trgt_fq_module);
133
134   const WholeOutput &cur_output = compute_outputs(trgt_fq_module, _h5file);
135   float error = _metric->compute(_ref_output, cur_output);
136   return error;
137 }