2a7b88f1fd4fab61a05b897372dfc3365f8fb074
[platform/core/ml/nnfw.git] / compiler / record-minmax / include / RecordMinMax.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __RECORD_MINMAX_H__
18 #define __RECORD_MINMAX_H__
19
20 #include <luci/IR/Module.h>
21 #include <luci_interpreter/Interpreter.h>
22
23 #include "MinMaxObserver.h"
24
25 #include <memory>
26 #include <thread>
27
28 namespace record_minmax
29 {
30
31 using Buffer = std::vector<char>;
32 using Output = std::vector<Buffer>;
33 using WholeOutput = std::vector<Output>;
34
35 class RecordMinMax
36 {
37 public:
38   explicit RecordMinMax(uint32_t num_threads) : _threads_size(num_threads)
39   {
40     assert(_threads_size > 0);
41   }
42
43   ~RecordMinMax() = default;
44
45   void initialize(const std::string &input_model_path);
46
47   void profileData(const std::string &mode, const std::string &input_data_path,
48                    float min_percentile, float max_percentile);
49
50   void profileDataInParallel(const std::string &mode, const std::string &input_data_path,
51                              float min_percentile, float max_percentile);
52
53   void profileRawData(const std::string &mode, const std::string &input_data_path,
54                       float min_percentile, float max_percentile);
55
56   void profileRawDataDirectory(const std::string &mode, const std::string &input_data_path,
57                                float min_percentile, float max_percentile);
58
59   void profileDataWithRandomInputs(const std::string &mode, float min_percentile,
60                                    float max_percentile);
61
62   void saveModel(const std::string &output_model_path);
63
64 private:
65   luci_interpreter::Interpreter *getInterpreter() const { return _interpreters[0].get(); }
66   MinMaxObserver *getObserver() const { return _observers[0].get(); }
67
68   WholeOutput importH5Data(const std::string &input_data_path);
69
70   std::unique_ptr<luci::Module> _module;
71
72   // Multiple interpreters are used for parallel execution
73   std::vector<std::unique_ptr<luci_interpreter::Interpreter>> _interpreters;
74   std::vector<std::unique_ptr<MinMaxObserver>> _observers;
75
76   uint32_t _threads_size = 0;
77 };
78
79 } // namespace record_minmax
80
81 #endif // __RECORD_MINMAX_H__