Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / record-minmax / include / MinMaxObserver.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __RECORD_MINMAX_MINMAXOBSERVER_H__
18 #define __RECORD_MINMAX_MINMAXOBSERVER_H__
19
20 #include <luci_interpreter/Interpreter.h>
21 #include <luci_interpreter/core/Tensor.h>
22
23 #include "MinMaxVectors.h"
24
25 #include <vector>
26 #include <unordered_map>
27
28 namespace record_minmax
29 {
30
31 class MinMaxMap
32 {
33 public:
34   // Record min/max of node
35   void recordMinMax(const luci::CircleNode *node, float min, float max)
36   {
37     MinMaxVectors &vectors = _minmax_map[node];
38     vectors.min_vector.push_back(min);
39     vectors.max_vector.push_back(max);
40   }
41
42   void appendMinMaxVector(const luci::CircleNode *node, const MinMaxVectors &minmax_vector)
43   {
44     MinMaxVectors &vectors = _minmax_map[node];
45     vectors.min_vector.insert(vectors.min_vector.end(), minmax_vector.min_vector.begin(),
46                               minmax_vector.min_vector.end());
47     vectors.max_vector.insert(vectors.max_vector.end(), minmax_vector.max_vector.begin(),
48                               minmax_vector.max_vector.end());
49   }
50
51   const std::unordered_map<const luci::CircleNode *, MinMaxVectors> *getMap() const
52   {
53     return &_minmax_map;
54   }
55
56 private:
57   std::unordered_map<const luci::CircleNode *, MinMaxVectors> _minmax_map;
58 };
59
60 class MinMaxObserver : public luci_interpreter::ExecutionObserver
61 {
62 public:
63   MinMaxObserver()
64   {
65     // Do nothing
66   }
67
68   void postTensorWrite(const luci::CircleNode *node,
69                        const luci_interpreter::Tensor *tensor) override;
70
71   // Never return nullptr
72   const MinMaxMap *minMaxData() { return &_minmax_data; }
73
74 private:
75   MinMaxMap _minmax_data;
76 };
77
78 } // namespace record_minmax
79
80 #endif // __RECORD_MINMAX_MINMAXOBSERVER_H__