Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-quantize
1 #!/usr/bin/env bash
2 ''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
6 ''''exit 255                                                                            # '''
7
8 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import os
24 import sys
25 import tempfile
26 import json
27
28 import onelib.utils as oneutils
29 from onelib.Command import Command
30
31 # TODO Find better way to suppress trackback on error
32 sys.tracebacklimit = 0
33
34
35 def _get_parser():
36     parser = argparse.ArgumentParser(
37         description='command line tool to quantize circle model')
38
39     oneutils.add_default_arg(parser)
40
41     # input and output path.
42     parser.add_argument(
43         '-i', '--input_path', type=str, help='full filepath of the input circle model')
44     parser.add_argument(
45         '-d',
46         '--input_data',
47         type=str,
48         help=
49         'full filepath of the input data used for post-training quantization. if not specified, run with random input data.'
50     )
51     parser.add_argument(
52         '-f',
53         '--input_data_format',
54         type=str,
55         help=
56         'file format of input data. h5/hdf5 (default), list/filelist (a text file where a file path of input data is written in each line), or dir/directory (a directory where input data are saved)'
57     )
58     parser.add_argument(
59         '-o',
60         '--output_path',
61         type=str,
62         help='full filepath of the output quantized model')
63
64     # argument for profiling
65     parser.add_argument(
66         '-p',
67         '--generate_profile_data',
68         action='store_true',
69         help='generate profiling data')
70
71     # save intermediate file(s)
72     parser.add_argument(
73         '--save_intermediate',
74         action='store_true',
75         help='Save intermediate files to output folder')
76
77     ## arguments for quantization
78     quantization_group = parser.add_argument_group('arguments for quantization')
79
80     quantization_group.add_argument(
81         '--input_dtype',
82         type=str,
83         help=
84         'input model data type (supported: float32, default=float32). Deprecated (Use input_model_dtype)'
85     )
86     quantization_group.add_argument(
87         '--input_model_dtype',
88         type=str,
89         help='input model data type (supported: float32, default=float32)')
90     quantization_group.add_argument(
91         '--quantized_dtype',
92         type=str,
93         help='data type of output quantized model (supported: uint8, int16, default=uint8)'
94     )
95     quantization_group.add_argument(
96         '--granularity',
97         type=str,
98         help='quantization granularity (supported: layer, channel, default=layer)')
99     quantization_group.add_argument(
100         '--input_type',
101         type=str,
102         help=
103         'data type of inputs of quantized model (supported: uint8, int16, float32, default=quantized_dtype). QUANTIZE Op will be inserted at the beginning of the quantized model if input_type is different from quantized_dtype.'
104     )
105     quantization_group.add_argument(
106         '--output_type',
107         type=str,
108         help=
109         'data type of outputs of quantized model (supported: uint8, int16, float32, default=quantized_dtype). QUANTIZE Op will be inserted at the end of the quantized model if output_type is different from quantized_dtype.'
110     )
111     quantization_group.add_argument(
112         '--min_percentile',
113         type=str,
114         help=
115         'minimum percentile (0.0~100.0, default=1.0). Algorithm parameter for calibration. This is valid when calibration algorithm is percentile.'
116     )
117     quantization_group.add_argument(
118         '--max_percentile',
119         type=str,
120         help=
121         'maximum percentile (0.0~100.0, default=99.0). Algorithm parameter for calibration. This is valid when calibration algorithm is percentile.'
122     )
123     quantization_group.add_argument(
124         '--moving_avg_batch',
125         type=str,
126         help=
127         'batch size of moving average (default=16). This is valid when calibration algorithm is moving_average.'
128     )
129     quantization_group.add_argument(
130         '--moving_avg_const',
131         type=str,
132         help=
133         'hyperparameter (C) to compute moving average (default=0.1). Update equation: avg <- avg + C * (curr_batch_avg - avg). This is valid when calibration algorithm is moving_average.'
134     )
135     quantization_group.add_argument(
136         '--mode',
137         type=str,
138         help=
139         "calibration algorithm for post-training quantization (supported: percentile/moving_average, default=percentile). 'percentile' mode uses the n-th percentiles as min/max values. 'moving_average' mode records the moving average of min/max."
140     )
141     quantization_group.add_argument(
142         '--TF-style_maxpool',
143         action='store_true',
144         help=
145         "Force MaxPool Op to have the same input/output quantparams. NOTE: This option can degrade accuracy of some models.)"
146     )
147     quantization_group.add_argument(
148         '--quant_config', type=str, help="Path to the quantization configuration file.")
149     quantization_group.add_argument(
150         '--evaluate_result',
151         action='store_true',
152         help=
153         "Evaluate accuracy of quantized model. Run inference for both fp32 model and the quantized model, and compare the inference results."
154     )
155     quantization_group.add_argument(
156         '--test_data', type=str, help="Path to the test data used for evaluation.")
157     quantization_group.add_argument(
158         '--print_mae',
159         action='store_true',
160         help=
161         "Print MAE (Mean Absolute Error) of inference results between quantized model and fp32 model."
162     )
163     quantization_group.add_argument(
164         '--print_mape',
165         action='store_true',
166         help=
167         "Print MAPE (Mean Absolute Percentage Error) of inference results between quantized model and fp32 model."
168     )
169     quantization_group.add_argument(
170         '--print_mpeir',
171         action='store_true',
172         help=
173         "Print MPEIR (Mean Peak Error to Interval Ratio) of inference results between quantized model and fp32 model."
174     )
175     quantization_group.add_argument(
176         '--print_top1_match',
177         action='store_true',
178         help=
179         "Print Top-1 match ratio of inference results between quantized model and fp32 model."
180     )
181     quantization_group.add_argument(
182         '--print_top5_match',
183         action='store_true',
184         help=
185         "Print Top-5 match ratio of inference results between quantized model and fp32 model."
186     )
187     quantization_group.add_argument(
188         '--print_mse',
189         action='store_true',
190         help=
191         "Print MSE (Mean Squared Error) of inference results between quantized model and fp32 model."
192     )
193
194     # arguments for force_quantparam option
195     force_quantparam_group = parser.add_argument_group(
196         'arguments for force_quantparam option')
197
198     force_quantparam_group.add_argument(
199         '--force_quantparam',
200         action='store_true',
201         help=
202         'overwrite quantparam (scale, zero_point) to the specified tensor in the quantized model.'
203     )
204     force_quantparam_group.add_argument(
205         '--tensor_name', type=str, action='append', help='tensor name (string)')
206     force_quantparam_group.add_argument(
207         '--scale', type=float, action='append', help='scale (float)')
208     force_quantparam_group.add_argument(
209         '--zero_point', type=int, action='append', help='zero point (int)')
210
211     # arguments for copy_quantparam option
212     copy_quantparam_group = parser.add_argument_group(
213         'arguments for copy_quantparam option')
214
215     copy_quantparam_group.add_argument(
216         '--copy_quantparam',
217         action='store_true',
218         help='copy quantparam (scale, zero_point) of a tensor to another tensor.')
219     copy_quantparam_group.add_argument(
220         '--src_tensor_name', type=str, action='append', help='tensor name (string)')
221     copy_quantparam_group.add_argument(
222         '--dst_tensor_name', type=str, action='append', help='tensor name (string)')
223
224     # arguments for fake_quant option
225     fake_quant_group = parser.add_argument_group('arguments for fake_quantize option')
226
227     fake_quant_group.add_argument(
228         '--fake_quantize',
229         action='store_true',
230         help='convert quantized model to fake-quantized fp32 model.')
231
232     # arguments for requantize option
233     requantize_group = parser.add_argument_group('arguments for requantize option')
234
235     requantize_group.add_argument(
236         '--requantize',
237         action='store_true',
238         help='convert quantized model to another-typed quantized model (ex: int8 -> uin8).'
239     )
240
241     # arguments for ampq option
242     ampq_quant_group = parser.add_argument_group('arguments for ampq option')
243     # ampq
244     ampq_quant_group.add_argument(
245         '--ampq', action='store_true', help='quantize model using ampq solver.')
246
247     # ampq_qerror_ratio
248     ampq_quant_group.add_argument(
249         '--ampq_qerror_ratio', type=str, help='quantization error ratio ([0, 1])')
250
251     # ampq_algorithm
252     ampq_quant_group.add_argument(
253         '--ampq_algorithm', type=str, help='type of algorithm (bisection)')
254
255     ampq_quant_group.add_argument(
256         '--bisection_type', type=str, help="one of 'auto', 'i16_front', 'i16_back'")
257
258     # ampq_bisection_visq
259     ampq_quant_group.add_argument(
260         '--ampq_bisection_visq',
261         type=str,
262         help='.visq.json file path with quantization errors')
263
264     return parser
265
266
267 def _set_default_values(args):
268     if not oneutils.is_valid_attr(args,
269                                   'input_model_dtype') and not oneutils.is_valid_attr(
270                                       args, 'input_dtype'):
271         setattr(args, 'input_model_dtype', 'float32')
272     if not oneutils.is_valid_attr(args, 'quantized_dtype'):
273         setattr(args, 'quantized_dtype', 'uint8')
274         if oneutils.is_valid_attr(args, 'quant_config'):
275             # Get quantized_dtype from qconfig file
276             try:
277                 with open(getattr(args, 'quant_config')) as f:
278                     qconf = json.load(f)
279                     if 'default_quantization_dtype' in qconf:
280                         setattr(args, 'quantized_dtype',
281                                 qconf['default_quantization_dtype'])
282             except json.decoder.JSONDecodeError:
283                 print('Failed to decode ' + getattr(args, 'quant_config') +
284                       '. Please check it is a json file.')
285     if not oneutils.is_valid_attr(args, 'granularity'):
286         setattr(args, 'granularity', 'layer')
287         if oneutils.is_valid_attr(args, 'quant_config'):
288             # Get granularity from qconfig file
289             try:
290                 with open(getattr(args, 'quant_config')) as f:
291                     qconf = json.load(f)
292                     if 'default_granularity' in qconf:
293                         setattr(args, 'granularity', qconf['default_granularity'])
294             except json.decoder.JSONDecodeError:
295                 print('Failed to decode ' + getattr(args, 'quant_config') +
296                       '. Please check it is a json file.')
297     if not oneutils.is_valid_attr(args, 'mode'):
298         setattr(args, 'mode', 'percentile')
299     if not oneutils.is_valid_attr(args, 'min_percentile'):
300         setattr(args, 'min_percentile', '1.0')
301     if not oneutils.is_valid_attr(args, 'max_percentile'):
302         setattr(args, 'max_percentile', '99.0')
303     if not oneutils.is_valid_attr(args, 'moving_avg_batch'):
304         setattr(args, 'moving_avg_batch', '16')
305     if not oneutils.is_valid_attr(args, 'moving_avg_const'):
306         setattr(args, 'moving_avg_const', '0.1')
307     if not oneutils.is_valid_attr(args, 'ampq_algorithm'):
308         setattr(args, 'ampq_algorithm', 'bisection')
309     if not oneutils.is_valid_attr(args, 'bisection_type'):
310         setattr(args, 'bisection_type', 'auto')
311
312
313 def _verify_arg_pre(parser, args):
314     """verify given arguments before default values are set"""
315     # check if required arguments is given
316     missing = []
317     if oneutils.is_valid_attr(args, 'requantize'):
318         if not oneutils.is_valid_attr(args,
319                                       'input_model_dtype') and not oneutils.is_valid_attr(
320                                           args, 'input_dtype'):
321             missing.append('--input_model_dtype')
322         if not oneutils.is_valid_attr(args, 'quantized_dtype'):
323             missing.append('--quantized_dtype')
324     if len(missing):
325         parser.error('the following arguments are required: ' + ' '.join(missing))
326
327
328 def _verify_arg(parser, args):
329     """verify given arguments"""
330     # check if required arguments is given
331     missing = []
332     if not oneutils.is_valid_attr(args, 'input_path'):
333         missing.append('-i/--input_path')
334     if not oneutils.is_valid_attr(args, 'output_path'):
335         missing.append('-o/--output_path')
336     if oneutils.is_valid_attr(args, 'force_quantparam'):
337         if not oneutils.is_valid_attr(args, 'tensor_name'):
338             missing.append('--tensor_name')
339         if not oneutils.is_valid_attr(args, 'scale'):
340             missing.append('--scale')
341         if not oneutils.is_valid_attr(args, 'zero_point'):
342             missing.append('--zero_point')
343     if oneutils.is_valid_attr(args, 'copy_quantparam'):
344         if not oneutils.is_valid_attr(args, 'src_tensor_name'):
345             missing.append('--src_tensor_name')
346         if not oneutils.is_valid_attr(args, 'dst_tensor_name'):
347             missing.append('--dst_tensor_name')
348     if len(missing):
349         parser.error('the following arguments are required: ' + ' '.join(missing))
350     if oneutils.is_valid_attr(args, 'force_quantparam'):
351         tensors = getattr(args, 'tensor_name')
352         scales = getattr(args, 'scale')
353         zerops = getattr(args, 'zero_point')
354         if len(tensors) != len(scales) or len(tensors) != len(zerops):
355             parser.error(
356                 'The same number of tensor_name, scale, and zero_point should be given.')
357     if oneutils.is_valid_attr(args, 'copy_quantparam'):
358         src_tensors = getattr(args, 'src_tensor_name')
359         dst_tensors = getattr(args, 'dst_tensor_name')
360         if len(src_tensors) != len(dst_tensors):
361             parser.error(
362                 'The same number of src_tensor_name and dst_tensor_name should be given.')
363
364     # Check calibration parameters
365     if oneutils.is_valid_attr(args, 'mode'):
366         if getattr(args, 'mode') == 'percentile':
367             # Check dtype
368             try:
369                 min_percentile = float(getattr(args, 'min_percentile'))
370             except ValueError:
371                 parser.error('min_percentile must be float')
372             try:
373                 max_percentile = float(getattr(args, 'max_percentile'))
374             except ValueError:
375                 parser.error('max_percentile must be float')
376         elif getattr(args, 'mode') == 'moving_average':
377             # Check dtype
378             try:
379                 moving_avg_batch = int(getattr(args, 'moving_avg_batch'))
380             except ValueError:
381                 parser.error('moving_avg_batch must be integer')
382             try:
383                 moving_avg_const = float(getattr(args, 'moving_avg_const'))
384             except ValueError:
385                 parser.error('moving_avg_const must be float')
386         else:
387             parser.error('Unsupported mode')
388
389
390 def _parse_arg(parser):
391     args = parser.parse_args()
392     # print version
393     if args.version:
394         oneutils.print_version_and_exit(__file__)
395
396     return args
397
398
399 def _quantize(args):
400     if oneutils.is_valid_attr(args, 'ampq'):
401         _ampq_solve(args)
402         return
403
404     if oneutils.is_valid_attr(args, 'force_quantparam'):
405         # write quantization parameters
406         _write_qparam(args)
407         return
408
409     if oneutils.is_valid_attr(args, 'copy_quantparam'):
410         # copy quantization parameters
411         _copy_qparam(args)
412         return
413
414     if oneutils.is_valid_attr(args, 'fake_quantize'):
415         # fake-quantize model
416         _fake_quantize(args)
417         return
418
419     if oneutils.is_valid_attr(args, 'requantize'):
420         # requantize model
421         _requantize(args)
422         return
423
424     # get file path to log
425     dir_path = os.path.dirname(os.path.realpath(__file__))
426     logfile_path = os.path.realpath(args.output_path) + '.log'
427
428     with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
429         if oneutils.is_valid_attr(args, 'save_intermediate'):
430             tmpdir = os.path.dirname(logfile_path)
431         # get driver path
432         circle_quantizer_path = os.path.join(dir_path, 'circle-quantizer')
433         record_minmax_path = os.path.join(dir_path, 'record-minmax')
434
435         ## make a command to quantize and dequantize the weights of the model
436         circle_quantizer_cmd = [circle_quantizer_path]
437         # verbose
438         if oneutils.is_valid_attr(args, 'verbose'):
439             circle_quantizer_cmd.append('--verbose')
440         # quantize_dequantize_weights
441         circle_quantizer_cmd.append('--quantize_dequantize_weights')
442         # Use input_model_dtype if it exists. Use input_dtype otherwise.
443         if oneutils.is_valid_attr(args, 'input_model_dtype'):
444             circle_quantizer_cmd.append(getattr(args, 'input_model_dtype'))
445         elif oneutils.is_valid_attr(args, 'input_dtype'):
446             circle_quantizer_cmd.append(getattr(args, 'input_dtype'))
447         if oneutils.is_valid_attr(args, 'quantized_dtype'):
448             circle_quantizer_cmd.append(getattr(args, 'quantized_dtype'))
449         if oneutils.is_valid_attr(args, 'granularity'):
450             circle_quantizer_cmd.append(getattr(args, 'granularity'))
451         if oneutils.is_valid_attr(args, 'quant_config'):
452             # NOTE --config conflicts with --config option in onecc, so
453             # we use quant_config for one-quantize
454             circle_quantizer_cmd.append('--config')
455             circle_quantizer_cmd.append(getattr(args, 'quant_config'))
456         # input and output path
457         if oneutils.is_valid_attr(args, 'input_path'):
458             circle_quantizer_cmd.append(getattr(args, 'input_path'))
459         tmp_weights_fake_quant_path = os.path.join(
460             tmpdir,
461             os.path.splitext(os.path.basename(
462                 args.input_path))[0]) + '.weights_fake_quant.circle'
463         circle_quantizer_cmd.append(tmp_weights_fake_quant_path)
464         # profiling
465         if oneutils.is_valid_attr(args, 'generate_profile_data'):
466             circle_quantizer_cmd.append('--generate_profile_data')
467
468         f.write((' '.join(circle_quantizer_cmd) + '\n').encode())
469
470         # run circle-quantizer
471         oneutils.run(circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
472
473         tmp_minmax_recorded_path = os.path.join(
474             tmpdir,
475             os.path.splitext(os.path.basename(
476                 args.input_path))[0]) + '.minmax_recorded.circle'
477
478         ## make a command to record min-max value of each tensor while running the representative dataset
479         record_minmax_cmd = Command(record_minmax_path, args, f)
480         record_minmax_cmd.add_noarg_option_if_valid_arg('--verbose', 'verbose') \
481             .add_option_with_values('--input_model', [tmp_weights_fake_quant_path]) \
482             .add_option_with_values('--output_model', [tmp_minmax_recorded_path]) \
483             .add_option_with_valid_args('--input_data', ['input_data']) \
484             .add_option_with_valid_args('--input_data_format', ['input_data_format']) \
485             .add_option_with_valid_args('--min_percentile', ['min_percentile']) \
486             .add_option_with_valid_args('--max_percentile', ['max_percentile']) \
487             .add_option_with_valid_args('--moving_avg_batch', ['moving_avg_batch']) \
488             .add_option_with_valid_args('--moving_avg_const', ['moving_avg_const']) \
489             .add_option_with_valid_args('--mode', ['mode']) \
490             .add_noarg_option_if_valid_arg('--generate_profile_data', 'generate_profile_data') \
491             .run()
492
493         ## make a second command to quantize the model using the embedded information
494         circle_quantizer_cmd = [circle_quantizer_path]
495         # verbose
496         if oneutils.is_valid_attr(args, 'verbose'):
497             circle_quantizer_cmd.append('--verbose')
498         # quantize_dequantize_weights
499         circle_quantizer_cmd.append('--quantize_with_minmax')
500         # Use input_model_dtype if it exists. Use input_dtype otherwise.
501         if oneutils.is_valid_attr(args, 'input_model_dtype'):
502             circle_quantizer_cmd.append(getattr(args, 'input_model_dtype'))
503         elif oneutils.is_valid_attr(args, 'input_dtype'):
504             circle_quantizer_cmd.append(getattr(args, 'input_dtype'))
505         if oneutils.is_valid_attr(args, 'quantized_dtype'):
506             circle_quantizer_cmd.append(getattr(args, 'quantized_dtype'))
507         if oneutils.is_valid_attr(args, 'granularity'):
508             circle_quantizer_cmd.append(getattr(args, 'granularity'))
509         if oneutils.is_valid_attr(args, 'TF-style_maxpool'):
510             circle_quantizer_cmd.append('--TF-style_maxpool')
511         if oneutils.is_valid_attr(args, 'input_type'):
512             circle_quantizer_cmd.append('--input_type')
513             circle_quantizer_cmd.append(getattr(args, 'input_type'))
514         if oneutils.is_valid_attr(args, 'output_type'):
515             circle_quantizer_cmd.append('--output_type')
516             circle_quantizer_cmd.append(getattr(args, 'output_type'))
517         if oneutils.is_valid_attr(args, 'quant_config'):
518             # NOTE --config conflicts with --config option in onecc, so
519             # we use quant_config for one-quantize
520             circle_quantizer_cmd.append('--config')
521             circle_quantizer_cmd.append(getattr(args, 'quant_config'))
522         # input and output path
523         circle_quantizer_cmd.append(tmp_minmax_recorded_path)
524         if oneutils.is_valid_attr(args, 'output_path'):
525             circle_quantizer_cmd.append(getattr(args, 'output_path'))
526         # profiling
527         if oneutils.is_valid_attr(args, 'generate_profile_data'):
528             circle_quantizer_cmd.append('--generate_profile_data')
529
530         f.write((' '.join(circle_quantizer_cmd) + '\n').encode())
531
532         # run circle-quantizer
533         oneutils.run(circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
534
535         # evaluate
536         if oneutils.is_valid_attr(args, 'evaluate_result'):
537             circle_eval_diff_path = os.path.join(dir_path, 'circle-eval-diff')
538             quant_model = ""
539             if oneutils.is_valid_attr(args, 'output_path'):
540                 quant_model = getattr(args, 'output_path')
541             tmp_fake_quant_model = os.path.join(
542                 tmpdir,
543                 os.path.splitext(os.path.basename(
544                     args.input_path))[0]) + '.fake_quant.circle'
545
546             # do fake quantization
547             fake_quantize_cmd = Command(circle_quantizer_path, args, f)
548             fake_quantize_cmd.add_noarg_option_if_valid_arg('--verbose', 'verbose') \
549                 .add_option_with_values('--fake_quantize', [quant_model, tmp_fake_quant_model]) \
550                 .run()
551
552             # compare fake-quant model and fp32 model
553             circle_eval_diff_cmd = Command(circle_eval_diff_path, args, f)
554             circle_eval_diff_cmd.add_option_with_valid_args('--first_model', ['input_path']) \
555                 .add_option_with_values('--second_model', [tmp_fake_quant_model]) \
556                 .add_option_with_valid_args('--first_input_data', ['test_data']) \
557                 .add_option_with_valid_args('--second_input_data', ['test_data']) \
558                 .add_option_with_valid_args('--input_data_format', ['input_data_format']) \
559                 .add_noarg_option_if_valid_arg('--print_mae', 'print_mae') \
560                 .add_noarg_option_if_valid_arg('--print_mape', 'print_mape') \
561                 .add_noarg_option_if_valid_arg('--print_mpeir', 'print_mpeir') \
562                 .add_noarg_option_if_valid_arg('--print_top1_match', 'print_top1_match') \
563                 .add_noarg_option_if_valid_arg('--print_top5_match', 'print_top5_match') \
564                 .add_noarg_option_if_valid_arg('--print_mse', 'print_mse') \
565                 .run()
566
567
568 def _write_qparam(args):
569     # get file path to log
570     dir_path = os.path.dirname(os.path.realpath(__file__))
571     logfile_path = os.path.realpath(args.output_path) + '.log'
572
573     with open(logfile_path, 'wb') as f:
574         # get driver path
575         circle_quantizer_path = os.path.join(dir_path, 'circle-quantizer')
576
577         # make a command to write qparams to the tensors
578         circle_quantizer_cmd = [circle_quantizer_path]
579         # verbose
580         if oneutils.is_valid_attr(args, 'verbose'):
581             circle_quantizer_cmd.append('--verbose')
582         if oneutils.is_valid_attr(args, 'tensor_name'):
583             tensor_name = getattr(args, 'tensor_name')
584         if oneutils.is_valid_attr(args, 'scale'):
585             scale = getattr(args, 'scale')
586         if oneutils.is_valid_attr(args, 'zero_point'):
587             zero_point = getattr(args, 'zero_point')
588         for (t, s, zp) in zip(tensor_name, scale, zero_point):
589             circle_quantizer_cmd.append('--force_quantparam')
590             circle_quantizer_cmd.append(t)
591             circle_quantizer_cmd.append(str(s))
592             circle_quantizer_cmd.append(str(zp))
593         # input and output path
594         if oneutils.is_valid_attr(args, 'input_path'):
595             circle_quantizer_cmd.append(getattr(args, 'input_path'))
596         if oneutils.is_valid_attr(args, 'output_path'):
597             circle_quantizer_cmd.append(getattr(args, 'output_path'))
598
599         f.write((' '.join(circle_quantizer_cmd) + '\n').encode())
600
601         # run circle-quantizer
602         oneutils.run(circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
603
604
605 def _copy_qparam(args):
606     # get file path to log
607     dir_path = os.path.dirname(os.path.realpath(__file__))
608     logfile_path = os.path.realpath(args.output_path) + '.log'
609
610     with open(logfile_path, 'wb') as f:
611         # get driver path
612         circle_quantizer_path = os.path.join(dir_path, 'circle-quantizer')
613
614         # make a command to write qparams to the tensors
615         circle_quantizer_cmd = [circle_quantizer_path]
616         # verbose
617         if oneutils.is_valid_attr(args, 'verbose'):
618             circle_quantizer_cmd.append('--verbose')
619         if oneutils.is_valid_attr(args, 'src_tensor_name'):
620             src_tensor_name = getattr(args, 'src_tensor_name')
621         if oneutils.is_valid_attr(args, 'dst_tensor_name'):
622             dst_tensor_name = getattr(args, 'dst_tensor_name')
623         for (src, dst) in zip(src_tensor_name, dst_tensor_name):
624             circle_quantizer_cmd.append('--copy_quantparam')
625             circle_quantizer_cmd.append(src)
626             circle_quantizer_cmd.append(dst)
627         # input and output path
628         if oneutils.is_valid_attr(args, 'input_path'):
629             circle_quantizer_cmd.append(getattr(args, 'input_path'))
630         if oneutils.is_valid_attr(args, 'output_path'):
631             circle_quantizer_cmd.append(getattr(args, 'output_path'))
632
633         f.write((' '.join(circle_quantizer_cmd) + '\n').encode())
634
635         # run circle-quantizer
636         oneutils.run(circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
637
638
639 def _fake_quantize(args):
640     # get file path to log
641     dir_path = os.path.dirname(os.path.realpath(__file__))
642     logfile_path = os.path.realpath(args.output_path) + '.log'
643
644     with open(logfile_path, 'wb') as f:
645         # get driver path
646         circle_quantizer_path = os.path.join(dir_path, 'circle-quantizer')
647         q_model = getattr(args, 'input_path')
648         fq_model = getattr(args, 'output_path')
649
650         # do fake quantization
651         fake_quantize_cmd = Command(circle_quantizer_path, args, f)
652         fake_quantize_cmd.add_noarg_option_if_valid_arg('--verbose', 'verbose') \
653             .add_option_with_values('--fake_quantize', [q_model, fq_model]) \
654             .run()
655
656
657 def _ampq_solve(args):
658     # get file path to log
659     dir_path = os.path.dirname(os.path.realpath(__file__))
660     logfile_path = os.path.realpath(args.output_path) + '.log'
661
662     with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
663         if oneutils.is_valid_attr(args, 'save_intermediate'):
664             tmpdir = os.path.dirname(logfile_path)
665
666         # get driver path
667         record_minmax_path = os.path.join(dir_path, 'record-minmax')
668
669         tmp_minmax_recorded_path = os.path.join(
670             tmpdir,
671             os.path.splitext(os.path.basename(
672                 args.input_path))[0]) + '.minmax_recorded.circle'
673
674         ## make a command to record min-max value of each tensor while running the representative dataset
675         record_minmax_cmd = Command(record_minmax_path, args, f)
676         record_minmax_cmd.add_noarg_option_if_valid_arg('--verbose', 'verbose') \
677             .add_option_with_valid_args('--input_model', ['input_path']) \
678             .add_option_with_values('--output_model', [tmp_minmax_recorded_path]) \
679             .add_option_with_valid_args('--input_data', ['input_data']) \
680             .add_option_with_valid_args('--input_data_format', ['input_data_format']) \
681             .add_option_with_valid_args('--min_percentile', ['min_percentile']) \
682             .add_option_with_valid_args('--max_percentile', ['max_percentile']) \
683             .add_option_with_valid_args('--moving_avg_batch', ['moving_avg_batch']) \
684             .add_option_with_valid_args('--moving_avg_const', ['moving_avg_const']) \
685             .add_option_with_valid_args('--mode', ['mode']) \
686             .add_noarg_option_if_valid_arg('--generate_profile_data', 'generate_profile_data') \
687             .run()
688
689         # process visq if needed
690         visq_file = None
691         if oneutils.is_valid_attr(args, 'ampq_bisection_visq'):
692             visq_file = getattr(args, 'ampq_bisection_visq')
693
694         if (oneutils.is_valid_attr(args, 'ampq_algorithm')
695                 and oneutils.is_valid_attr(args, 'bisection_type')):
696             algorithm = getattr(args, 'ampq_algorithm')
697             bisection_type = getattr(args, 'bisection_type')
698             if algorithm == 'bisection' and bisection_type == 'auto' and visq_file is None:
699                 # algorithm needs bisection but no file in input configuration
700
701                 # to compute visq file we need q8 quantized model
702                 q8_file = os.path.join(
703                     tmpdir,
704                     os.path.splitext(os.path.basename(
705                         args.input_path))[0]) + '.visq.q8.circle'
706
707                 # get drievr path
708                 circle_quantizer_path = os.path.join(dir_path, 'circle-quantizer')
709                 circle_quantizer_cmd = [circle_quantizer_path]
710                 # verbose
711                 if oneutils.is_valid_attr(args, 'verbose'):
712                     circle_quantizer_cmd.append('--verbose')
713                 circle_quantizer_cmd.append('--quantize_with_minmax')
714                 circle_quantizer_cmd.append('float32')
715                 circle_quantizer_cmd.append('uint8')
716                 circle_quantizer_cmd.append('channel')
717
718                 if oneutils.is_valid_attr(args, 'TF-style_maxpool'):
719                     circle_quantizer_cmd.append('--TF-style_maxpool')
720
721                 circle_quantizer_cmd.extend(['--input_type', 'uint8'])
722                 circle_quantizer_cmd.extend(['--output_type', 'uint8'])
723
724                 # input and output paths
725                 circle_quantizer_cmd.append(tmp_minmax_recorded_path)
726                 circle_quantizer_cmd.append(q8_file)
727
728                 f.write((' '.join(circle_quantizer_cmd) + '\n').encode())
729
730                 # run circle-quantizer
731                 oneutils.run(
732                     circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
733
734                 # compute visq file
735                 visq_path = os.path.join(dir_path, 'visq')
736
737                 visq_file = os.path.join(
738                     tmpdir,
739                     os.path.splitext(os.path.basename(
740                         args.input_path))[0]) + '.tae.visq.json'
741
742                 visq_cmd = [visq_path]
743                 visq_cmd.extend(['--fp32_circle', getattr(args, 'input_path')])
744                 visq_cmd.extend(['--data', getattr(args, 'input_data')])
745                 visq_cmd.extend(['--q_circle', q8_file])
746                 visq_cmd.extend(['--tae_output', visq_file])
747                 visq_cmd.extend(['--batch_size', "1"])
748                 visq_cmd.append('--dump_dot_graph')
749                 f.write((' '.join(visq_cmd) + '\n').encode())
750
751                 # run visq
752                 oneutils.run(visq_cmd, err_prefix="visq", logfile=f)
753
754         # get driver path
755         circle_mpqsolver_path = os.path.join(dir_path, 'circle-mpqsolver')
756
757         # solve for Mixed Precision Quantization configuration
758         ampq_quantize_cmd = [circle_mpqsolver_path]
759
760         # data
761         if oneutils.is_valid_attr(args, 'input_data'):
762             ampq_quantize_cmd.extend(['--data', getattr(args, 'input_data')])
763
764         # data format
765         if oneutils.is_valid_attr(args, 'input_data_format'):
766             ampq_quantize_cmd.extend(
767                 ['--data_format', getattr(args, 'input_data_format')])
768
769         # qerror_ratio
770         if oneutils.is_valid_attr(args, 'ampq_qerror_ratio'):
771             ampq_quantize_cmd.extend(
772                 ['--qerror_ratio', getattr(args, 'ampq_qerror_ratio')])
773
774         # algorithm
775         if oneutils.is_valid_attr(args, 'ampq_algorithm'):
776             algorithm = getattr(args, 'ampq_algorithm')
777             if algorithm == 'bisection':
778                 if oneutils.is_valid_attr(args, 'bisection_type'):
779                     bisection_type = getattr(args, 'bisection_type')
780                     if bisection_type == 'auto':
781                         ampq_quantize_cmd.extend(['--bisection', 'auto'])
782                     elif bisection_type == 'i16_front':
783                         ampq_quantize_cmd.extend(['--bisection', 'true'])
784                     elif bisection_type == 'i16_back':
785                         ampq_quantize_cmd.extend(['--bisection', 'false'])
786
787         # recorded model as input
788         ampq_quantize_cmd.extend(['--input_model', tmp_minmax_recorded_path])
789
790         # input_dtype
791         if oneutils.is_valid_attr(args, 'input_type'):
792             ampq_quantize_cmd.extend(['--input_dtype', getattr(args, 'input_type')])
793
794         # output dtype
795         if oneutils.is_valid_attr(args, 'output_type'):
796             ampq_quantize_cmd.extend(['--output_dtype', getattr(args, 'output_type')])
797
798         # output model
799         if oneutils.is_valid_attr(args, 'output_path'):
800             ampq_quantize_cmd.extend(['--output_model', getattr(args, 'output_path')])
801
802         # visq_file
803         if not (visq_file is None):
804             ampq_quantize_cmd.extend(['--visq_file', visq_file])
805
806         # save_intermediate
807         if oneutils.is_valid_attr(args, 'save_intermediate'):
808             intermediate_dir = os.path.dirname(logfile_path)
809             ampq_quantize_cmd.extend(['--save_intermediate', intermediate_dir])
810
811         if oneutils.is_valid_attr(args, 'verbose'):
812             ampq_quantize_cmd.append('--verbose')
813
814         f.write((' '.join(ampq_quantize_cmd) + '\n').encode())
815
816         # run ampq
817         oneutils.run(ampq_quantize_cmd, err_prefix="circle_mpqsolver", logfile=f)
818
819
820 def _requantize(args):
821     # get file path to log
822     dir_path = os.path.dirname(os.path.realpath(__file__))
823     logfile_path = os.path.realpath(args.output_path) + '.log'
824
825     with open(logfile_path, 'wb') as f:
826         # get driver path
827         circle_quantizer_path = os.path.join(dir_path, 'circle-quantizer')
828
829         ## make a command to quantize and dequantize the weights of the model
830         circle_quantizer_cmd = [circle_quantizer_path]
831         # verbose
832         if oneutils.is_valid_attr(args, 'verbose'):
833             circle_quantizer_cmd.append('--verbose')
834         # requantize
835         circle_quantizer_cmd.append('--requantize')
836         # Use input_model_dtype if it exists. Use input_dtype otherwise.
837         if oneutils.is_valid_attr(args, 'input_model_dtype'):
838             circle_quantizer_cmd.append(getattr(args, 'input_model_dtype'))
839         elif oneutils.is_valid_attr(args, 'input_dtype'):
840             circle_quantizer_cmd.append(getattr(args, 'input_dtype'))
841         if oneutils.is_valid_attr(args, 'quantized_dtype'):
842             circle_quantizer_cmd.append(getattr(args, 'quantized_dtype'))
843         # input and output path
844         if oneutils.is_valid_attr(args, 'input_path'):
845             circle_quantizer_cmd.append(getattr(args, 'input_path'))
846         if oneutils.is_valid_attr(args, 'output_path'):
847             circle_quantizer_cmd.append(getattr(args, 'output_path'))
848
849         f.write((' '.join(circle_quantizer_cmd) + '\n').encode())
850
851         # run circle-quantizer
852         oneutils.run(circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
853
854
855 def main():
856     # parse arguments
857     parser = _get_parser()
858     args = _parse_arg(parser)
859
860     # parse configuration file
861     oneutils.parse_cfg(args.config, 'one-quantize', args)
862
863     # verify arguments before default value setting
864     _verify_arg_pre(parser, args)
865
866     # set default values
867     _set_default_values(args)
868
869     # verify arguments
870     _verify_arg(parser, args)
871
872     # quantize
873     _quantize(args)
874
875
876 if __name__ == '__main__':
877     oneutils.safemain(main, __file__)