024842824da3523aeed11f44d7fc4edeb3c32194
[platform/core/ml/nnfw.git] / compiler / visq / visqlib / DumpFakeQuantFM.py
1 # Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 # Script that dumps dequantized FM
16 # NOTE This script runs on dalgona
17
18 import numpy as np
19
20 from pathlib import Path
21 from Util import to_filename
22
23 # Fake-quantized Op has the postfix of fq_postfix
24 # TODO Remove coupling with fake quantization codes
25 fq_postfix = '_FQ_Quantize_FQ_Dequantize'
26
27
28 # Return the original name before fake quantization
29 # Return None if name is not from fake quantization (Dequantize Op in original model)
30 # TODO Handle the case when the original node's name contains fq_postfix
31 def _name_before_fq(name):
32     if not name.endswith(fq_postfix):
33         return None
34
35     return name[0:name.find(fq_postfix)]
36
37
38 # Dump fake-quantized model's intermediate FM data according to tensors.txt
39 #
40 # Before
41 # self._dir/
42 #  tensors.txt
43 #
44 # After
45 # self._dir/
46 #  tensors.txt
47 #  <TENSOR_NAME>.npy
48 # NOTE TENSOR_NAME is transformed by to_filename
49 class DumpFakeQuantFM:
50     def StartAnalysis(self, args):
51         self._dir = Path(args)
52         self._num_data = 0
53         with open(self._dir / 'tensors.txt') as f:
54             self._target_tensors = set([line.rstrip() for line in f])
55
56     def EndNetworkExecution(self, outputs: list):
57         self._num_data += 1
58
59     # TODO Use DequantizePost when dalgona supports it
60     def DefaultOpPost(self, name, opcode, inputs, output):
61         if opcode == 'Dequantize':
62             orig_name = _name_before_fq(name)
63             if orig_name in self._target_tensors:
64                 data_path = self._dir / str(self._num_data)
65                 data_path.mkdir(parents=False, exist_ok=True)
66                 np.save(str(data_path / to_filename(orig_name)), output['data'])