3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
class _CONSTANT:
    """Read-only namespace for constants shared by the one-cmds drivers.

    The class name is rebound to its single instance right after the class
    body. Because ``__slots__`` is empty, that instance has no ``__dict__``, so
    its attributes cannot be reassigned through ``_CONSTANT``.
    """
    __slots__ = ()  # This prevents access via __dict__.
    OPTIMIZATION_OPTS = (
        # (OPTION_NAME, HELP_MESSAGE)
        ('O1', 'enable O1 optimization pass'),
        ('convert_nchw_to_nhwc',
         'Experimental: This will convert NCHW operators to NHWC under the assumption that input model is NCHW.'
         ),
        ('expand_broadcast_const', 'expand broadcastable constant node inputs'),
        ('nchw_to_nhwc_input_shape',
         'convert the input shape of the model (argument for convert_nchw_to_nhwc)'),
        ('nchw_to_nhwc_output_shape',
         'convert the output shape of the model (argument for convert_nchw_to_nhwc)'),
        ('fold_add_v2', 'fold AddV2 op with constant inputs'),
        ('fold_cast', 'fold Cast op with constant input'),
        ('fold_dequantize', 'fold Dequantize op'),
        ('fold_dwconv', 'fold Depthwise Convolution op with constant inputs'),
        ('fold_sparse_to_dense', 'fold SparseToDense op'),
        ('forward_reshape_to_unaryop', 'Forward Reshape op'),
        ('fuse_add_with_tconv', 'fuse Add op to Transposed Convolution op'),
        ('fuse_add_with_fully_connected', 'fuse Add op to FullyConnected op'),
        ('fuse_batchnorm_with_conv', 'fuse BatchNorm op to Convolution op'),
        ('fuse_batchnorm_with_dwconv', 'fuse BatchNorm op to Depthwise Convolution op'),
        ('fuse_batchnorm_with_tconv', 'fuse BatchNorm op to Transposed Convolution op'),
        ('fuse_bcq', 'apply Binary Coded Quantization'),
        ('fuse_preactivation_batchnorm',
         'fuse BatchNorm operators of pre-activations to Convolution op'),
        ('fuse_mean_with_mean', 'fuse two consecutive Mean ops'),
        ('fuse_transpose_with_mean',
         'fuse Mean with a preceding Transpose under certain conditions'),
        ('make_batchnorm_gamma_positive',
         'make negative gamma of BatchNorm to a small positive value (1e-10).'
         ' Note that this pass can change the execution result of the model.'
         ' So, use it only when the impact is known to be acceptable.'),
        ('fuse_activation_function', 'fuse Activation function to a preceding operator'),
        ('fuse_instnorm', 'fuse ops to InstanceNorm operator'),
        ('replace_cw_mul_add_with_depthwise_conv',
         'replace channel-wise Mul/Add with DepthwiseConv2D'),
        ('remove_fakequant', 'remove FakeQuant ops'),
        ('remove_quantdequant', 'remove Quantize-Dequantize sequence'),
        ('remove_redundant_reshape', 'fuse or remove subsequent Reshape ops'),
        ('remove_redundant_transpose', 'fuse or remove subsequent Transpose ops'),
        ('remove_unnecessary_reshape', 'remove unnecessary reshape ops'),
        ('remove_unnecessary_slice', 'remove unnecessary slice ops'),
        ('remove_unnecessary_strided_slice', 'remove unnecessary strided slice ops'),
        ('remove_unnecessary_split', 'remove unnecessary split ops'),
        ('resolve_customop_add', 'convert Custom(Add) op to Add op'),
        ('resolve_customop_batchmatmul',
         'convert Custom(BatchMatmul) op to BatchMatmul op'),
        ('resolve_customop_matmul', 'convert Custom(Matmul) op to Matmul op'),
        ('resolve_customop_max_pool_with_argmax',
         'convert Custom(MaxPoolWithArgmax) to net of builtin operators'),
        ('shuffle_weight_to_16x1float32',
         'convert weight format of FullyConnected op to SHUFFLED16x1FLOAT32.'
         ' Note that it only converts weights whose row is a multiple of 16'),
        ('substitute_pack_to_reshape', 'convert single input Pack op to Reshape op'),
        ('substitute_padv2_to_pad', 'convert certain condition PadV2 to Pad'),
        ('substitute_splitv_to_split', 'convert certain condition SplitV to Split'),
        ('substitute_squeeze_to_reshape', 'convert certain condition Squeeze to Reshape'),
        ('substitute_strided_slice_to_reshape',
         'convert certain condition StridedSlice to Reshape'),
        ('substitute_transpose_to_reshape',
         'convert certain condition Transpose to Reshape'),
        ('transform_min_max_to_relu6', 'transform Minimum-Maximum pattern to Relu6 op'),
        ('transform_min_relu_to_relu6', 'transform Minimum(6)-Relu pattern to Relu6 op'))


# Replace the class by its only instance so the constants are effectively read-only.
_CONSTANT = _CONSTANT()
def _add_default_arg(parser):
    """Register the options shared by the one-cmds drivers on ``parser``.

    Adds:
        -v/--version  store_true flag requesting the program version
        -V/--verbose  store_true flag requesting additional output
        -C/--config   path of a configuration file (str)
        -S/--section  section of the configuration file to run (hidden from help)

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.
    """
    # version
    parser.add_argument(
        '-v',
        '--version',
        action='store_true',
        help='show program\'s version number and exit')

    # verbose
    parser.add_argument(
        '-V',
        '--verbose',
        action='store_true',
        help='output additional information to stdout or stderr')

    # configuration file
    parser.add_argument('-C', '--config', type=str, help='run with configuration file')
    # section name that you want to run in configuration file
    parser.add_argument('-S', '--section', type=str, help=argparse.SUPPRESS)
def is_accumulated_arg(arg, driver):
    """Tell whether option ``arg`` is collected into a list for ``driver``.

    Only the 'one-quantize' driver has such options: 'tensor_name', 'scale'
    and 'zero_point'. Every other option/driver pair yields False.
    """
    if driver != "one-quantize":
        return False
    return arg in ("tensor_name", "scale", "zero_point")
def _is_valid_attr(args, attr):
    """Return ``args.<attr>`` if that attribute exists, otherwise False.

    Callers test the result for truthiness, so an attribute that is missing
    or holds a falsy value (None, '', False, 0, empty list) counts as unset.
    """
    if not hasattr(args, attr):
        return False
    return getattr(args, attr)
def _parse_cfg(args, driver_name):
    """Merge options from the configuration file ``args.config`` into ``args``.

    Options given directly on the command line take priority over the
    configuration file: a config value is written to ``args`` only when the
    matching attribute is missing or falsy. Options for which
    ``is_accumulated_arg(key, driver_name)`` is true are kept as lists; config
    values are appended to a list that is already present.

    The section read is ``args.section`` when given, otherwise the section
    named after the driver (``driver_name``). Nothing happens when no
    configuration file is given.

    Args:
        args: parsed argument namespace; updated in place.
        driver_name: name of the calling driver, e.g. 'one-quantize'.

    Raises:
        AssertionError: if the configuration file cannot be read or does not
            contain the section to run.
    """
    if not _is_valid_attr(args, 'config'):
        return

    config = configparser.ConfigParser()
    # keep option names case-sensitive (ConfigParser lower-cases them by default)
    config.optionxform = str
    # ConfigParser.read silently skips unreadable files and returns the list of
    # files it did read; report that case explicitly instead of a confusing
    # "missing section" error.
    if not config.read(args.config):
        raise AssertionError(f"cannot read configuration file '{args.config}'")

    # an explicitly requested section wins; otherwise use the driver's own section
    if _is_valid_attr(args, 'section'):
        section_to_run = args.section
    else:
        section_to_run = driver_name
    if not config.has_section(section_to_run):
        raise AssertionError('configuration file must have \'' + section_to_run +
                             '\' section')

    for key in config[section_to_run]:
        value = config[section_to_run][key]
        if is_accumulated_arg(key, driver_name):
            if not _is_valid_attr(args, key):
                setattr(args, key, [value])
            else:
                getattr(args, key).append(value)
        elif not _is_valid_attr(args, key):
            setattr(args, key, value)
def _make_tf2tfliteV2_cmd(args, driver_path, input_path, output_path):
    """Build the argv list that runs tf2tfliteV2.py with the current interpreter.

    '--input_path'/'--output_path' are emitted only when ``args`` has a truthy
    attribute of that name. The path values themselves come from the
    ``input_path``/``output_path`` parameters, with ``~`` expanded.

    Returns:
        list of command-line tokens.
    """
    command = [sys.executable, os.path.expanduser(driver_path)]

    if _is_valid_attr(args, 'verbose'):
        command.append('--verbose')

    # model format: a ready-made flag wins, then the named format,
    # falling back to '--graph_def'
    if _is_valid_attr(args, 'model_format_cmd'):
        command.append(args.model_format_cmd)
    elif _is_valid_attr(args, 'model_format'):
        command.append('--' + args.model_format)
    else:
        command.append('--graph_def')

    # converter version: same precedence, falling back to '--v1'
    if _is_valid_attr(args, 'converter_version_cmd'):
        command.append(args.converter_version_cmd)
    elif _is_valid_attr(args, 'converter_version'):
        command.append('--' + args.converter_version)
    else:
        command.append('--v1')

    for name, path in (('input_path', input_path), ('output_path', output_path)):
        if _is_valid_attr(args, name):
            command += ['--' + name, os.path.expanduser(path)]

    # options forwarded verbatim as '--<name> <value>'
    for name in ('input_arrays', 'input_shapes', 'output_arrays'):
        if _is_valid_attr(args, name):
            command += ['--' + name, getattr(args, name)]

    return command
def _make_tflite2circle_cmd(driver_path, input_path, output_path):
    """Build the argv list that runs tflite2circle.

    Returns ``[driver_path, input_path, output_path]`` with ``~`` expanded in
    each entry.
    """
    return list(map(os.path.expanduser, (driver_path, input_path, output_path)))
def _make_circle2circle_cmd(args, driver_path, input_path, output_path):
    """Build the argv list that runs circle2circle.

    The command starts with ``driver_path input_path output_path`` (each with
    ``~`` expanded). It is followed by '--generate_profile_data' when requested
    and by one '--<name>' flag per enabled option in
    ``_CONSTANT.OPTIMIZATION_OPTS``.

    An option counts as enabled when ``args.<name>`` is ``True``, or is a string
    that is not a false-like spelling. String values are the config-file case.
    Other value types are ignored.

    Returns:
        list of command-line tokens.
    """
    cmd = [os.path.expanduser(c) for c in [driver_path, input_path, output_path]]
    # profiling
    if _is_valid_attr(args, 'generate_profile_data'):
        cmd.append('--generate_profile_data')

    # In a config file an option is usually written 'SomeOption=True', but
    # during development users may write 'SomeOption=False' instead of removing
    # the line. Treat configparser's false spellings ('0', 'no', 'false', 'off')
    # and the historical 'n' as disabled.
    false_values = ('false', '0', 'n', 'no', 'off')

    # optimization passes (only true/false options)
    # TODO support options whose number of arguments is more than zero
    for opt_name, _help in _CONSTANT.OPTIMIZATION_OPTS:
        if not _is_valid_attr(args, opt_name):
            continue
        value = getattr(args, opt_name)
        if isinstance(value, bool):
            # ./driver --<opt_name>
            cmd.append('--' + opt_name)
        elif isinstance(value, str) and value.lower() not in false_values:
            cmd.append('--' + opt_name)

    return cmd
def _print_version_and_exit(file_path):
    """Print the version of the tool located at ``file_path``, then exit.

    Symlinks in ``file_path`` are resolved first. The ``one-version``
    executable in the same directory is then run with the tool's base name
    (without extension) as its argument, and the process exits.
    """
    script_path = os.path.realpath(file_path)
    dir_path = os.path.dirname(script_path)
    script_name = os.path.splitext(os.path.basename(script_path))[0]
    # run one-version
    subprocess.call([os.path.join(dir_path, 'one-version'), script_name])
    # the function's contract is to terminate after printing
    sys.exit()
def _safemain(main, mainpath):
    """Run ``main()`` and report any uncaught exception with the program name.

    An ``Exception`` raised by ``main`` is printed to stderr as
    ``<program>: <ExceptionType>: <message>``, where ``<program>`` is the base
    name of ``mainpath``. The process then exits with status 255.
    ``SystemExit`` and ``KeyboardInterrupt`` are not ``Exception`` subclasses
    and propagate unchanged.
    """
    try:
        main()
    except Exception as e:
        prog_name = os.path.basename(mainpath)
        print(f"{prog_name}: {type(e).__name__}: {e}", file=sys.stderr)
        # never report a failure as success
        sys.exit(255)
def _run(cmd, err_prefix=None, logfile=None):
    """Execute command in subprocess, relaying its output line by line.

    Lines read from the child's stdout go to this process's stdout, and lines
    from its stderr go to this process's stderr. Both are written as raw bytes.

    Args:
        cmd: command to be executed in subprocess
        err_prefix: prefix to be put before every stderr lines
        logfile: binary file stream to which both of stdout and stderr lines
            will be written (stderr lines are logged with their prefix)

    Raises:
        SystemExit: with the child's return code when it is non-zero.
    """
    import select

    # Pipes are binary. bufsize=-1 is the buffering binary pipes get anyway;
    # bufsize=1 (line buffering) is text-mode only and merely emits a
    # RuntimeWarning for binary pipes on Python 3.8+.
    with subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=-1) as p:
        inputs = {p.stdout, p.stderr}
        # select() interleaves both pipes, so neither one can fill up and
        # block the child while we wait on the other.
        while inputs:
            readable, _, _ = select.select(inputs, [], [])
            for stream in readable:
                line = stream.readline()
                if not line:  # EOF on this pipe
                    inputs.discard(stream)
                    continue
                if stream is p.stderr:
                    out = sys.stderr
                    if err_prefix:
                        line = f"{err_prefix}: ".encode() + line
                else:
                    out = sys.stdout
                out.buffer.write(line)
                out.buffer.flush()
                if logfile is not None:
                    logfile.write(line)
    # Popen.__exit__ waits for the child, so returncode is set here.
    if p.returncode != 0:
        sys.exit(p.returncode)