1 # Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
# Script that dumps the dequantized feature maps (FM) of a fake-quantized model
# NOTE This script runs on dalgona
import json
from pathlib import Path

import numpy as np
# Fake-quantized Op has the postfix of fq_postfix
# TODO Remove coupling with fake quantization codes
fq_postfix = '_FQ_Quantize_FQ_Dequantize'


# Return the original name before fake quantization, i.e. `name` with the
# trailing fq_postfix removed.
# Return None if name is not from fake quantization (Dequantize Op in original model)
# NOTE An original node whose own name ends with fq_postfix is still
#      indistinguishable from a fake-quantized one.
def _name_before_fq(name):
    if not name.endswith(fq_postfix):
        return None
    # Strip only the trailing postfix. Cutting at name.find(fq_postfix) would
    # truncate names that contain fq_postfix earlier in the string.
    return name[:len(name) - len(fq_postfix)]
# Dump fake-quantized model's intermediate FM data according to tensors.json
#
# Files under self._dir:
#   tensors.json                 (input)  {TENSOR_NAME -> TENSOR_ID}
#   <DATA_IDX>/<TENSOR_ID>.npy   (output) dequantized FM of each listed tensor;
#                                one sub-directory per network execution
#   scales.txt                   (output) JSON {TENSOR_NAME -> scale}
# NOTE tensors.json has a dictionary {TENSOR_NAME -> TENSOR_ID}
class DumpFakeQuantFM:
    # Initialize analysis state. args is the directory that holds
    # tensors.json; all results are written under the same directory.
    def StartAnalysis(self, args):
        self._dir = Path(args)
        # Index of the current network execution; names the sub-directory
        # that execution's FM data is saved into.
        self._num_data = 0
        with open(self._dir / 'tensors.json') as f:
            self._tname_to_tid = json.load(f)
        # {original tensor name -> scale}, filled once per tensor
        self._scale_map = {}

    # Move on to the next data index after one network execution.
    def EndNetworkExecution(self, outputs: list):
        self._num_data += 1

    # Save the output of every fake-quantized Dequantize Op whose original
    # tensor is listed in tensors.json, and record its quantization scale.
    # TODO Use DequantizePost when dalgona supports it
    def DefaultOpPost(self, name, opcode, inputs, outputs):
        if opcode != 'Dequantize':
            return

        orig_name = _name_before_fq(name)
        # JSON object keys are always strings, so a None orig_name
        # (not fake-quantized) is never found here.
        if orig_name not in self._tname_to_tid:
            return

        tid = self._tname_to_tid[orig_name]
        data_path = self._dir / str(self._num_data)
        for output in outputs:
            data_path.mkdir(parents=False, exist_ok=True)
            # assumes dalgona passes each output tensor as a dict holding its
            # values under 'data' (like the 'quantparam' entry read from
            # inputs below) -- TODO confirm against dalgona's hook API
            np.save(str(data_path / str(tid)), output['data'])

        # Save scales (scale is fixed, so saving once)
        if orig_name not in self._scale_map:
            assert len(inputs) == 1
            assert 'quantparam' in inputs[0]
            quantparam = inputs[0]['quantparam']
            assert 'scale' in quantparam
            # Only a single scale (per-tensor quantization) is expected
            assert len(quantparam['scale']) == 1
            # float() keeps the value JSON-serializable even if it arrives
            # as a numpy scalar
            self._scale_map[orig_name] = float(quantparam['scale'][0])

    # Dump saved scales into scales.txt (JSON content).
    def EndAnalysis(self):
        with open(self._dir / 'scales.txt', 'w') as f:
            json.dump(self._scale_map, f)