Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / core / DumpingHooks.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __MPQSOLVER_DUMPING_HOOKS_H__
18 #define __MPQSOLVER_DUMPING_HOOKS_H__
19
20 #include <luci/IR/Module.h>
21
22 #include <core/Quantizer.h>
23 #include <core/SolverHooks.h>
24 #include <core/Dumper.h>
25
26 #include <string>
27
28 namespace mpqsolver
29 {
30 namespace core
31 {
32
33 /**
34  * @brief DumpingHooks is intended to save intermediate results
35  */
36 class DumpingHooks final : public QuantizerHook, public SolverHooks
37 {
38 public:
39   /**
40    * @brief DumpingHooks constructor
41    * @param save_path directory where all intermediate data will be saved
42    */
43   DumpingHooks(const std::string &save_path);
44
45   /**
46    * @brief called on successfull quantization
47    */
48   virtual void on_quantized(luci::Module *module) const override;
49
50   /**
51    * @brief called on the start of iterative search
52    */
53   virtual void on_begin_solver(const std::string &model_path, float q8error,
54                                float q16error) override;
55
56   /**
57    * @brief called on the start of current iteration
58    */
59   virtual void on_begin_iteration() override;
60
61   /**
62    * @brief called at the end of current iteration
63    */
64   virtual void on_end_iteration(const LayerParams &layers, const std::string &def_dtype,
65                                 float error) const override;
66
67   /**
68    * @brief called at the end of iterative search
69    */
70   virtual void on_end_solver(const LayerParams &layers, const std::string &def_dtype,
71                              float qerror) override;
72
73 protected:
74   std::string _model_path;
75   std::string _save_path;
76   Dumper _dumper;
77   uint32_t _num_of_iterations = 0;
78   bool _in_iterations = false;
79 };
80
81 } // namespace core
82 } // namespace mpqsolver
83
84 #endif //__MPQSOLVER_DUMPING_HOOKS_H__