Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / luci-value-test / tester / src / EvalTester.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <luci/Importer.h>
18 #include <luci_interpreter/Interpreter.h>
19 #include <luci/CircleExporter.h>
20 #include <luci/CircleFileExpContract.h>
21
22 #include <cstdlib>
23 #include <fstream>
24 #include <iostream>
25 #include <vector>
26 #include <map>
27 #include <string>
28 #include <random>
29
30 namespace
31 {
32
33 void readDataFromFile(const std::string &filename, char *data, size_t data_size)
34 {
35   std::ifstream fs(filename, std::ifstream::binary);
36   if (fs.fail())
37     throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
38   if (fs.read(data, data_size).fail())
39     throw std::runtime_error("Failed to read data from file \"" + filename + "\".\n");
40 }
41
42 void writeDataToFile(const std::string &filename, const char *data, size_t data_size)
43 {
44   std::ofstream fs(filename, std::ofstream::binary);
45   if (fs.fail())
46     throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
47   if (fs.write(data, data_size).fail())
48   {
49     throw std::runtime_error("Failed to write data to file \"" + filename + "\".\n");
50   }
51 }
52
53 std::unique_ptr<luci::Module> importModel(const std::string &filename)
54 {
55   std::ifstream fs(filename, std::ifstream::binary);
56   if (fs.fail())
57   {
58     throw std::runtime_error("Cannot open model file \"" + filename + "\".\n");
59   }
60   std::vector<char> model_data((std::istreambuf_iterator<char>(fs)),
61                                std::istreambuf_iterator<char>());
62   return luci::Importer().importModule(circle::GetModel(model_data.data()));
63 }
64
65 template <typename NodeT> size_t getTensorSize(const NodeT *node)
66 {
67   uint32_t tensor_size = loco::size(node->dtype());
68   for (uint32_t i = 0; i < node->rank(); ++i)
69     tensor_size *= node->dim(i).value();
70   return tensor_size;
71 }
72
73 } // namespace
74
75 /*
76  * @brief EvalTester main
77  *
78  *        Driver for testing luci-inerpreter
79  *
80  */
81 int entry(int argc, char **argv)
82 {
83   if (argc != 5)
84   {
85     std::cerr
86         << "Usage: " << argv[0]
87         << " <path/to/circle/model> <num_inputs> <path/to/input/prefix> <path/to/output/file>\n";
88     return EXIT_FAILURE;
89   }
90
91   const char *filename = argv[1];
92   const int32_t num_inputs = atoi(argv[2]);
93   const char *input_prefix = argv[3];
94   const char *output_file = argv[4];
95   const std::string intermediate_filename = std::string(filename) + ".inter.circle";
96
97   // Load model from the file
98   std::unique_ptr<luci::Module> initial_module = importModel(filename);
99   if (initial_module == nullptr)
100   {
101     std::cerr << "ERROR: Failed to load '" << filename << "'" << std::endl;
102     return EXIT_FAILURE;
103   }
104
105   // Export to a Circle file
106   luci::CircleExporter exporter;
107
108   luci::CircleFileExpContract contract(initial_module.get(), intermediate_filename);
109
110   if (!exporter.invoke(&contract))
111   {
112     std::cerr << "ERROR: Failed to export '" << intermediate_filename << "'" << std::endl;
113     return EXIT_FAILURE;
114   }
115
116   // Import model again
117   std::unique_ptr<luci::Module> module = importModel(intermediate_filename);
118   if (module == nullptr)
119   {
120     std::cerr << "ERROR: Failed to load '" << intermediate_filename << "'" << std::endl;
121     return EXIT_FAILURE;
122   }
123
124   // Create interpreter.
125   luci_interpreter::Interpreter interpreter(module.get());
126
127   // Set input.
128   // Data for n'th input is read from ${input_prefix}n
129   // (ex: Add.circle.input0, Add.circle.input1 ..)
130   const auto input_nodes = loco::input_nodes(module->graph());
131   assert(num_inputs == input_nodes.size());
132   for (int32_t i = 0; i < num_inputs; i++)
133   {
134     const auto *input_node = loco::must_cast<const luci::CircleInput *>(input_nodes[i]);
135     std::vector<char> input_data(getTensorSize(input_node));
136     readDataFromFile(std::string(input_prefix) + std::to_string(i), input_data.data(),
137                      input_data.size());
138     interpreter.writeInputTensor(input_node, input_data.data(), input_data.size());
139   }
140
141   // Do inference.
142   interpreter.interpret();
143
144   // Get output.
145   const auto output_nodes = loco::output_nodes(module->graph());
146   for (int i = 0; i < module->graph()->outputs()->size(); i++)
147   {
148     const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
149     std::vector<char> output_data(getTensorSize(output_node));
150     interpreter.readOutputTensor(output_node, output_data.data(), output_data.size());
151
152     // Output data is written in ${output_file}
153     // (ex: Add.circle.output0)
154     // Output shape is written in ${output_file}.shape
155     // (ex: Add.circle.output0.shape)
156     writeDataToFile(std::string(output_file) + std::to_string(i), output_data.data(),
157                     output_data.size());
158     // In case of Tensor output is Scalar value.
159     // The output tensor with rank 0 is treated as a scalar with shape (1)
160     if (output_node->rank() == 0)
161     {
162       writeDataToFile(std::string(output_file) + std::to_string(i) + ".shape", "1", 1);
163     }
164     else
165     {
166       auto shape_str = std::to_string(output_node->dim(0).value());
167       for (int j = 1; j < output_node->rank(); j++)
168       {
169         shape_str += ",";
170         shape_str += std::to_string(output_node->dim(j).value());
171       }
172       writeDataToFile(std::string(output_file) + std::to_string(i) + ".shape", shape_str.c_str(),
173                       shape_str.size());
174     }
175   }
176   return EXIT_SUCCESS;
177 }