/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef __RECORD_MINMAX_H__
18 #define __RECORD_MINMAX_H__
#include <luci/IR/Module.h>
#include <luci_interpreter/Interpreter.h>

#include "MinMaxObserver.h"
#include "MinMaxComputer.h"

#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
29 namespace record_minmax
32 using Buffer = std::vector<char>;
33 using Output = std::vector<Buffer>;
34 using WholeOutput = std::vector<Output>;
39 explicit RecordMinMax(uint32_t num_threads, std::unique_ptr<MinMaxComputer> &&minmax_computer)
40 : _threads_size(num_threads), _minmax_computer(std::move(minmax_computer))
42 assert(_threads_size > 0);
43 assert(_minmax_computer != nullptr);
46 ~RecordMinMax() = default;
48 void initialize(const std::string &input_model_path);
50 // TODO Refactor profile functions
51 void profileData(const std::string &input_data_path);
53 void profileDataInParallel(const std::string &input_data_path);
55 void profileRawData(const std::string &input_data_path);
57 void profileRawDataDirectory(const std::string &input_data_path);
59 void profileDataWithRandomInputs(void);
61 void saveModel(const std::string &output_model_path);
64 luci_interpreter::Interpreter *getInterpreter() const { return _interpreters[0].get(); }
66 // Never return nullptr
67 MinMaxObserver *getObserver() const
69 assert(_observers.size() > 0); // FIX CALLER UNLESS
70 assert(_observers[0].get()); // FIX CALLER UNLESS
71 return _observers[0].get();
74 WholeOutput importH5Data(const std::string &input_data_path);
76 std::unique_ptr<luci::Module> _module;
78 // Multiple interpreters are used for parallel execution
79 std::vector<std::unique_ptr<luci_interpreter::Interpreter>> _interpreters;
80 std::vector<std::unique_ptr<MinMaxObserver>> _observers;
82 uint32_t _threads_size = 0;
83 std::unique_ptr<MinMaxComputer> _minmax_computer;
86 } // namespace record_minmax
88 #endif // __RECORD_MINMAX_H__