Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / calibration / command_line_processor.py
1 """
2 Copyright (C) 2018-2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import tempfile
18
19 from ..accuracy_checker.accuracy_checker.config import ConfigReader
20 from ..accuracy_checker.accuracy_checker.launcher.dlsdk_launcher import DLSDKLauncher
21
22 from ..network import Network
23 from ..utils.path import Path
24 from ..utils.configuration_filter import ConfigurationFilter
25 from .calibration_configuration import CalibrationConfiguration
26 from .logging import info, default_logger
27 from .command_line_reader import CommandLineReader
28
29
class CommandLineProcessor:
    """
    Class for parsing user input config.

    Merges the command line arguments with the accuracy checker configuration
    and builds a CalibrationConfiguration for the single model/launcher pair
    found in the result.
    """

    # Launchers describing a not-yet-converted source model, in priority order:
    # (framework name passed to DLSDKLauncher.convert_model,
    #  launcher key of the model file, launcher key of the weights file).
    # TF and MXNet provide a single file, which is used to derive both the
    # "_i8" output model path and the output weights path.
    _SOURCE_FRAMEWORKS = (
        ('caffe', 'caffe_model', 'caffe_weights'),
        ('tf', 'tf_model', 'tf_model'),
        ('mxnet', 'mxnet_weights', 'mxnet_weights'),
    )

    # Source-model keys removed from the launcher once it points at the IR.
    _SOURCE_MODEL_KEYS = ('caffe_model', 'caffe_weights', 'tf_model', 'mxnet_weights')

    @staticmethod
    def process() -> CalibrationConfiguration:
        """
        Parse the command line and build the calibration configuration.

        Returns:
            CalibrationConfiguration describing the model, weights, output paths,
            device, batch size and layer-ignore settings for calibration.

        Raises:
            ValueError: if the merged configuration does not contain exactly one
                model with exactly one launcher.
        """
        args, unknown_args = CommandLineReader.parser().parse_known_args()
        if unknown_args:
            info("unknown command line arguments: {0}".format(unknown_args))

        # Calibration always goes through the DLSDK launcher.
        args.target_framework = "dlsdk"
        args.aocl = None

        merged_config = ConfigReader.merge(args)
        updated_config = ConfigurationFilter.filter(merged_config, args.metric_name, args.metric_type, default_logger)

        launcher = CommandLineProcessor._get_single_launcher(updated_config)
        output_dir = str(args.output_dir) if args.output_dir else None

        source = next(
            (entry for entry in CommandLineProcessor._SOURCE_FRAMEWORKS if entry[1] in launcher),
            None)
        if source is not None:
            model, weights, output_model, output_weights, tmp_directory = \
                CommandLineProcessor._convert_source_model(launcher, source, args.converted_models, output_dir)
        else:
            # The launcher already references an IR model.
            model = launcher['model']
            weights = launcher['weights']
            output_model, output_weights = CommandLineProcessor._output_paths(model, weights, output_dir)
            tmp_directory = None

        # Batch size precedence: command line, then launcher 'batch', then the
        # batch size stored in the network itself.
        batch_size = args.batch_size if args.batch_size else launcher.get('batch')
        if not batch_size:
            with Network(str(launcher['model']), str(launcher['weights'])) as network:
                batch_size = network.ie_network.batch_size

        if 'cpu_extensions' in launcher:
            cpu_extension = DLSDKLauncher.get_cpu_extension(launcher['cpu_extensions'], args.cpu_extensions_mode)
            launcher['cpu_extensions'] = cpu_extension
        else:
            cpu_extension = None

        # Unless explicitly requested, FullyConnected layers are not calibrated.
        if not args.calibrate_fully_connected:
            if args.ignore_layer_types is None:
                args.ignore_layer_types = []
            args.ignore_layer_types.append("FullyConnected")

        return CalibrationConfiguration(
            config=updated_config,
            precision=args.precision,
            model=str(model),
            weights=str(weights),
            tmp_directory=tmp_directory,
            output_model=output_model,
            output_weights=output_weights,
            cpu_extension=str(cpu_extension) if cpu_extension else None,
            gpu_extension=str(launcher['gpu_extensions']) if 'gpu_extensions' in launcher else None,
            device=launcher['device'],
            batch_size=batch_size,
            threshold=args.threshold,
            ignore_layer_types=args.ignore_layer_types,
            ignore_layer_types_path=args.ignore_layer_types_path,
            ignore_layer_names=args.ignore_layer_names,
            ignore_layer_names_path=args.ignore_layer_names_path,
            benchmark_iterations_count=args.benchmark_iterations_count,
            progress=(None if args.progress == 'None' else args.progress))

    @staticmethod
    def _get_single_launcher(config) -> dict:
        """
        Return the only launcher of the only model in the configuration.

        Raises:
            ValueError: if there is not exactly one model, or that model does not
                have exactly one launcher.
        """
        models = config.get('models') or []
        if not models:
            raise ValueError("no models in configuration")
        if len(models) > 1:
            raise ValueError("too much models")

        launchers = models[0].get('launchers') or []
        if not launchers:
            raise ValueError("no launchers in configuration")
        if len(launchers) > 1:
            raise ValueError("too much launchers")

        return launchers[0]

    @staticmethod
    def _output_paths(model_path, weights_path, output_dir):
        """Return the "_i8" output (model, weights) paths, placed in output_dir when it is given."""
        return (Path.get_model(str(model_path), "_i8", output_dir),
                Path.get_weights(str(weights_path), "_i8", output_dir))

    @staticmethod
    def _convert_source_model(launcher: dict, source: tuple, converted_models, output_dir):
        """
        Convert a source-framework model referenced by the launcher to IR.

        The launcher is updated in place: 'model'/'weights' point at the
        converted IR, and the source-model keys are removed.

        Args:
            launcher: launcher configuration; mutated in place.
            source: (framework, model key, weights key) entry of _SOURCE_FRAMEWORKS.
            converted_models: when truthy, no temporary directory is created and
                the conversion output location is left to the merged config.
                Presumably it was filled from this option; verify in ConfigReader.
            output_dir: directory for the "_i8" outputs, or None.

        Returns:
            (model, weights, output_model, output_weights, tmp_directory), where
            tmp_directory is None unless a temporary directory was created.
        """
        import shutil

        framework, model_key, weights_key = source

        if converted_models:
            tmp_directory = None
        else:
            tmp_directory = tempfile.mkdtemp(".converted_models")
            # Create mo_params on demand instead of failing with KeyError (which
            # would also leak the directory just created).
            mo_params = launcher.get('mo_params') or {}
            mo_params['output_dir'] = tmp_directory
            launcher['mo_params'] = mo_params

        try:
            # Output paths are derived from the source model files, before conversion.
            output_model, output_weights = CommandLineProcessor._output_paths(
                launcher[model_key], launcher[weights_key], output_dir)
            model, weights = DLSDKLauncher.convert_model(launcher, framework)
        except Exception:
            # Do not leave an orphaned temporary directory behind on failure.
            if tmp_directory is not None:
                shutil.rmtree(tmp_directory, ignore_errors=True)
            raise

        launcher['model'] = model
        launcher['weights'] = weights
        for key in CommandLineProcessor._SOURCE_MODEL_KEYS:
            launcher.pop(key, None)

        return model, weights, output_model, output_weights, tmp_directory