2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __RECORD_MINMAX_H__
18 #define __RECORD_MINMAX_H__
#include <luci/IR/Module.h>
#include <luci_interpreter/Interpreter.h>

#include "MinMaxObserver.h"

#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
28 namespace record_minmax
31 using Buffer = std::vector<char>;
32 using Output = std::vector<Buffer>;
33 using WholeOutput = std::vector<Output>;
38 explicit RecordMinMax(uint32_t num_threads) : _threads_size(num_threads)
40 assert(_threads_size > 0);
43 ~RecordMinMax() = default;
45 void initialize(const std::string &input_model_path);
47 void profileData(const std::string &mode, const std::string &input_data_path,
48 float min_percentile, float max_percentile);
50 void profileDataInParallel(const std::string &mode, const std::string &input_data_path,
51 float min_percentile, float max_percentile);
53 void profileRawData(const std::string &mode, const std::string &input_data_path,
54 float min_percentile, float max_percentile);
56 void profileRawDataDirectory(const std::string &mode, const std::string &input_data_path,
57 float min_percentile, float max_percentile);
59 void profileDataWithRandomInputs(const std::string &mode, float min_percentile,
60 float max_percentile);
62 void saveModel(const std::string &output_model_path);
65 luci_interpreter::Interpreter *getInterpreter() const { return _interpreters[0].get(); }
66 MinMaxObserver *getObserver() const { return _observers[0].get(); }
68 WholeOutput importH5Data(const std::string &input_data_path);
70 std::unique_ptr<luci::Module> _module;
72 // Multiple interpreters are used for parallel execution
73 std::vector<std::unique_ptr<luci_interpreter::Interpreter>> _interpreters;
74 std::vector<std::unique_ptr<MinMaxObserver>> _observers;
76 uint32_t _threads_size = 0;
79 } // namespace record_minmax
81 #endif // __RECORD_MINMAX_H__