Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / record-minmax / include / RecordMinMax.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __RECORD_MINMAX_H__
18 #define __RECORD_MINMAX_H__
19
20 #include <luci/IR/Module.h>
21 #include <luci_interpreter/Interpreter.h>
22
23 #include "MinMaxObserver.h"
24 #include "MinMaxComputer.h"
25
26 #include <memory>
27 #include <thread>
28
29 namespace record_minmax
30 {
31
32 using Buffer = std::vector<char>;
33 using Output = std::vector<Buffer>;
34 using WholeOutput = std::vector<Output>;
35
36 class RecordMinMax
37 {
38 public:
39   explicit RecordMinMax(uint32_t num_threads, std::unique_ptr<MinMaxComputer> &&minmax_computer)
40     : _threads_size(num_threads), _minmax_computer(std::move(minmax_computer))
41   {
42     assert(_threads_size > 0);
43     assert(_minmax_computer != nullptr);
44   }
45
46   ~RecordMinMax() = default;
47
48   void initialize(const std::string &input_model_path);
49
50   // TODO Refactor profile functions
51   void profileData(const std::string &input_data_path);
52
53   void profileDataInParallel(const std::string &input_data_path);
54
55   void profileRawData(const std::string &input_data_path);
56
57   void profileRawDataDirectory(const std::string &input_data_path);
58
59   void profileDataWithRandomInputs(void);
60
61   void saveModel(const std::string &output_model_path);
62
63 private:
64   luci_interpreter::Interpreter *getInterpreter() const { return _interpreters[0].get(); }
65
66   // Never return nullptr
67   MinMaxObserver *getObserver() const
68   {
69     assert(_observers.size() > 0); // FIX CALLER UNLESS
70     assert(_observers[0].get());   // FIX CALLER UNLESS
71     return _observers[0].get();
72   }
73
74   WholeOutput importH5Data(const std::string &input_data_path);
75
76   std::unique_ptr<luci::Module> _module;
77
78   // Multiple interpreters are used for parallel execution
79   std::vector<std::unique_ptr<luci_interpreter::Interpreter>> _interpreters;
80   std::vector<std::unique_ptr<MinMaxObserver>> _observers;
81
82   uint32_t _threads_size = 0;
83   std::unique_ptr<MinMaxComputer> _minmax_computer;
84 };
85
86 } // namespace record_minmax
87
88 #endif // __RECORD_MINMAX_H__