Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-import-bcq
1 #!/usr/bin/env bash
2 ''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"                  # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
6 ''''exit 255                                                                            # '''
7
8 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import os
24 import subprocess
25 import sys
26 import tempfile
27
28 import utils as _utils
29 import generate_bcq_output_arrays as _bcq_info_gen
30
31
32 def _get_parser():
33     parser = argparse.ArgumentParser(
34         description='command line tool to convert TensorFlow with BCQ to circle')
35
36     _utils._add_default_arg(parser)
37
38     ## tf2tfliteV2 arguments
39     tf2tfliteV2_group = parser.add_argument_group('converter arguments')
40
41     # converter version
42     converter_version = tf2tfliteV2_group.add_mutually_exclusive_group()
43     converter_version.add_argument(
44         '--v1',
45         action='store_const',
46         dest='converter_version_cmd',
47         const='--v1',
48         help='use TensorFlow Lite Converter 1.x')
49     converter_version.add_argument(
50         '--v2',
51         action='store_const',
52         dest='converter_version_cmd',
53         const='--v2',
54         help='use TensorFlow Lite Converter 2.x')
55
56     parser.add_argument('--converter_version', type=str, help=argparse.SUPPRESS)
57
58     # input and output path.
59     tf2tfliteV2_group.add_argument(
60         '-i', '--input_path', type=str, help='full filepath of the input file')
61     tf2tfliteV2_group.add_argument(
62         '-o', '--output_path', type=str, help='full filepath of the output file')
63
64     # input and output arrays.
65     tf2tfliteV2_group.add_argument(
66         '-I',
67         '--input_arrays',
68         type=str,
69         help='names of the input arrays, comma-separated')
70     tf2tfliteV2_group.add_argument(
71         '-s',
72         '--input_shapes',
73         type=str,
74         help=
75         'shapes corresponding to --input_arrays, colon-separated (ex:"1,4,4,3:1,20,20,3")'
76     )
77     tf2tfliteV2_group.add_argument(
78         '-O',
79         '--output_arrays',
80         type=str,
81         help='names of the output arrays, comma-separated')
82
83     return parser
84
85
86 def _verify_arg(parser, args):
87     """verify given arguments"""
88     # check if required arguments is given
89     missing = []
90     if not _utils._is_valid_attr(args, 'input_path'):
91         missing.append('-i/--input_path')
92     if not _utils._is_valid_attr(args, 'output_path'):
93         missing.append('-o/--output_path')
94     if len(missing):
95         parser.error('the following arguments are required: ' + ' '.join(missing))
96
97
98 def _parse_arg(parser):
99     args = parser.parse_args()
100     # print version
101     if args.version:
102         _utils._print_version_and_exit(__file__)
103
104     return args
105
106
107 def _make_generate_bcq_metadata_cmd(args, driver_path, output_path):
108     """make a command for running generate_bcq_metadata"""
109     cmd = [sys.executable, driver_path]
110     # input_path
111     if _utils._is_valid_attr(args, 'input_path'):
112         cmd.append('--input_path')
113         cmd.append(os.path.expanduser(getattr(args, 'input_path')))
114     # output_path
115     if _utils._is_valid_attr(args, 'output_path'):
116         cmd.append('--output_path')
117         cmd.append(os.path.expanduser(output_path))
118     # output_arrays
119     if _utils._is_valid_attr(args, 'output_arrays'):
120         cmd.append('--output_arrays')
121         cmd.append(getattr(args, 'output_arrays'))
122
123     return cmd
124
125
126 def _convert(args):
127     # get file path to log
128     dir_path = os.path.dirname(os.path.realpath(__file__))
129     logfile_path = os.path.realpath(args.output_path) + '.log'
130
131     with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
132         # make a command to generate BCQ information metadata
133         generate_bcq_metadata_path = os.path.join(dir_path, 'generate_bcq_metadata.py')
134         generate_bcq_metadata_output_path = os.path.join(
135             tmpdir,
136             os.path.splitext(os.path.basename(args.input_path))[0] + '_withmeta.pb')
137         generate_bcq_metadata_cmd = _make_generate_bcq_metadata_cmd(
138             args, generate_bcq_metadata_path, generate_bcq_metadata_output_path)
139
140         f.write((' '.join(generate_bcq_metadata_cmd) + '\n').encode())
141
142         # generate BCQ information metadata
143         with subprocess.Popen(
144                 generate_bcq_metadata_cmd,
145                 stdout=subprocess.PIPE,
146                 stderr=subprocess.STDOUT,
147                 bufsize=1) as p:
148             for line in p.stdout:
149                 sys.stdout.buffer.write(line)
150                 f.write(line)
151         if p.returncode != 0:
152             sys.exit(p.returncode)
153
154         # get output_arrays with BCQ
155         bcq_output_arrays = _bcq_info_gen.get_bcq_output_arrays(
156             generate_bcq_metadata_output_path, getattr(args, 'output_arrays'))
157
158         # make a command to convert from tf with BCQ to tflite
159         tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
160         tf2tfliteV2_output_path = os.path.join(
161             tmpdir,
162             os.path.splitext(
163                 os.path.basename(generate_bcq_metadata_output_path))[0]) + '.tflite'
164         tf2tfliteV2_cmd = _utils._make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
165                                                        generate_bcq_metadata_output_path,
166                                                        tf2tfliteV2_output_path)
167         try:
168             output_arrays_idx = tf2tfliteV2_cmd.index('--output_arrays')
169             tf2tfliteV2_cmd[output_arrays_idx + 1] = ','.join(bcq_output_arrays)
170         except ValueError:
171             pass
172
173         f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
174
175         # convert tf to tflite
176         with subprocess.Popen(
177                 tf2tfliteV2_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
178                 bufsize=1) as p:
179             for line in p.stdout:
180                 sys.stdout.buffer.write(line)
181                 f.write(line)
182         if p.returncode != 0:
183             sys.exit(p.returncode)
184
185         # make a command to convert from tflite to circle
186         tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
187         tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
188                                                            tf2tfliteV2_output_path,
189                                                            getattr(args, 'output_path'))
190
191         f.write((' '.join(tflite2circle_cmd) + '\n').encode())
192
193         # convert tflite to circle
194         with subprocess.Popen(
195                 tflite2circle_cmd,
196                 stdout=subprocess.PIPE,
197                 stderr=subprocess.STDOUT,
198                 bufsize=1) as p:
199             for line in p.stdout:
200                 sys.stdout.buffer.write(line)
201                 f.write(line)
202         if p.returncode != 0:
203             sys.exit(p.returncode)
204
205
206 def main():
207     # parse arguments
208     parser = _get_parser()
209     args = _parse_arg(parser)
210
211     # parse configuration file
212     _utils._parse_cfg(args, 'one-import-bcq')
213
214     # verify arguments
215     _verify_arg(parser, args)
216
217     # _convert
218     _convert(args)
219
220
221 if __name__ == '__main__':
222     main()