Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / core / Evaluator.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Evaluator.h"
18
19 #include <luci_interpreter/Interpreter.h>
20
21 #include <dio_hdf5/HDF5Importer.h>
22
23 using namespace mpqsolver::core;
24
25 using Shape = std::vector<loco::Dimension>;
26
27 namespace
28 {
29
30 using namespace luci;
31
32 template <typename NodeT> size_t get_tensor_size(const NodeT *node)
33 {
34   uint32_t tensor_size = loco::size(node->dtype());
35   for (uint32_t i = 0; i < node->rank(); ++i)
36     tensor_size *= node->dim(i).value();
37   return tensor_size;
38 }
39
40 WholeOutput compute_outputs(const luci::Module *module, const std::string &h5file)
41 {
42   dio::hdf5::HDF5Importer importer{h5file};
43   importer.importGroup("value");
44
45   bool is_raw_data = importer.isRawData();
46
47   const auto num_records = importer.numData();
48   if (num_records == 0)
49     throw std::runtime_error("The input data file does not contain any record.");
50   const auto input_nodes = loco::input_nodes(module->graph());
51   const auto num_inputs = input_nodes.size();
52
53   WholeOutput dataset_output;
54
55   // Create interpreter.
56   luci_interpreter::Interpreter interpreter(module);
57   for (int32_t record_idx = 0; record_idx < num_records; record_idx++)
58   {
59     if (num_inputs != static_cast<uint32_t>(importer.numInputs(record_idx)))
60       throw std::runtime_error("Wrong number of inputs.");
61     for (uint32_t input_idx = 0; input_idx < num_inputs; input_idx++)
62     {
63       const auto *input_node = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
64       assert(input_node->index() == input_idx);
65
66       std::vector<char> input_data(get_tensor_size(input_node));
67
68       if (!is_raw_data)
69       {
70         loco::DataType dtype;
71         Shape shape;
72         importer.readTensor(record_idx, input_idx, &dtype, &shape, input_data.data(),
73                             input_data.size());
74       }
75       else
76       {
77         // Skip type/shape check for raw data
78         importer.readTensor(record_idx, input_idx, input_data.data(), input_data.size());
79       }
80
81       interpreter.writeInputTensor(input_node, input_data.data(), input_data.size());
82     }
83
84     interpreter.interpret();
85
86     Output nn_output;
87
88     // Get output.
89     const auto output_nodes = loco::output_nodes(module->graph());
90     for (size_t i = 0; i < module->graph()->outputs()->size(); i++)
91     {
92       const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
93       Buffer output_data(get_tensor_size(output_node));
94       interpreter.readOutputTensor(output_node, output_data.data(), output_data.size());
95       // output
96       nn_output.push_back(output_data);
97     }
98     dataset_output.push_back(nn_output);
99   }
100
101   return dataset_output;
102 }
103
104 } // namespace
105
106 DatasetEvaluator::DatasetEvaluator(const luci::Module *ref_module, const std::string &h5file,
107                                    const ErrorMetric &metric)
108   : _ref_module(ref_module), _h5file(h5file), _metric(&metric)
109 {
110   _ref_output = compute_outputs(_ref_module, _h5file);
111 }
112
113 void DatasetEvaluator::validate(const luci::Module *trgt_fq_module) const
114 {
115   const auto output_nodes = loco::output_nodes(trgt_fq_module->graph());
116   for (size_t out_index = 0; out_index < output_nodes.size(); ++out_index)
117   {
118     const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[out_index]);
119     loco::DataType out_dtype = output_node->dtype();
120     if (out_dtype != loco::DataType::FLOAT32)
121       throw std::runtime_error("Unsupported output dtype " + output_node->name());
122   }
123 }
124
125 float DatasetEvaluator::evaluate(const luci::Module *trgt_fq_module) const
126 {
127   if (trgt_fq_module == nullptr)
128     throw std::runtime_error("Invalid target module");
129
130   if (_metric == nullptr)
131     throw std::runtime_error("Invalid metric");
132
133   validate(trgt_fq_module);
134
135   const WholeOutput &cur_output = compute_outputs(trgt_fq_module, _h5file);
136   float error = _metric->compute(_ref_output, cur_output);
137   return error;
138 }