2 Copyright (C) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from ..accuracy_checker.accuracy_checker.config import ConfigReader
20 from ..accuracy_checker.accuracy_checker.launcher.dlsdk_launcher import DLSDKLauncher
22 from ..network import Network
23 from ..utils.path import Path
24 from ..utils.configuration_filter import ConfigurationFilter
25 from .calibration_configuration import CalibrationConfiguration
26 from .logging import info, default_logger
27 from .command_line_reader import CommandLineReader
class CommandLineProcessor:
    """
    Class for parsing user input config
    """

    # One entry per model format that must be converted before calibration:
    # (launcher key naming the source model, launcher key naming the source
    # weights, framework id handed to DLSDKLauncher.convert_model).
    # Order matters: the first key found in the launcher wins.
    # NOTE(review): the framework ids are assumed to be the ones
    # DLSDKLauncher.convert_model expects - confirm against the launcher.
    _CONVERTIBLE_SOURCES = (
        ('caffe_model', 'caffe_weights', "caffe"),
        ('tf_model', 'tf_model', "tf"),
        ('mxnet_weights', 'mxnet_weights', "mxnet"),
    )

    @staticmethod
    def process() -> CalibrationConfiguration:
        """
        Build the calibration configuration from the command line.

        Parses the command line with CommandLineReader, merges it into the
        accuracy checker configuration, filters that configuration by the
        requested metric, converts the model to IR when the launcher
        references a Caffe / TensorFlow / MXNet model, and resolves the
        output paths, batch size and CPU extension.

        Returns:
            CalibrationConfiguration populated from the parsed arguments
            and the single launcher of the single configured model.

        Raises:
            ValueError: if the filtered configuration contains more than
                one model, or the model has more than one launcher.
        """
        # Imported locally: the module's visible import block does not
        # bring tempfile into scope.
        import tempfile

        args, unknown_args = CommandLineReader.parser().parse_known_args()
        if unknown_args:
            # Only report when there is actually something unrecognised.
            info("unknown command line arguments: {0}".format(unknown_args))

        args.target_framework = "dlsdk"

        merged_config = ConfigReader.merge(args)
        updated_config = ConfigurationFilter.filter(merged_config, args.metric_name, args.metric_type, default_logger)

        if len(updated_config['models']) > 1:
            raise ValueError("too much models")

        if len(updated_config['models'][0]['launchers']) > 1:
            raise ValueError("too much launchers")

        launcher = updated_config['models'][0]['launchers'][0]
        output_dir = str(args.output_dir) if args.output_dir else None

        # Bound on every path: it is always forwarded to the result below.
        tmp_directory = None

        conversion_source = next(
            (entry for entry in CommandLineProcessor._CONVERTIBLE_SOURCES if entry[0] in launcher),
            None)

        if conversion_source is not None:
            model_key, weights_key, framework = conversion_source

            # NOTE(review): a temporary directory is created only when no
            # directory for converted models was supplied - confirm this is
            # the intended meaning of `converted_models`.
            if not args.converted_models:
                tmp_directory = tempfile.mkdtemp(".converted_models")
                launcher['mo_params']['output_dir'] = tmp_directory

            # Output names derive from the source files, with the same
            # "_i8" suffix used for models that are already converted.
            output_model = Path.get_model(str(launcher[model_key]), "_i8", output_dir)
            output_weights = Path.get_weights(str(launcher[weights_key]), "_i8", output_dir)

            model, weights = DLSDKLauncher.convert_model(launcher, framework)
            launcher['model'] = model
            launcher['weights'] = weights

            # The launcher now points at the converted model, so drop the
            # framework-specific source entries.
            for source_key in ('caffe_model', 'caffe_weights', 'tf_model', 'mxnet_weights'):
                launcher.pop(source_key, None)
        else:
            model = launcher['model']
            weights = launcher['weights']
            output_model = Path.get_model(str(model), "_i8", output_dir)
            output_weights = Path.get_weights(str(weights), "_i8", output_dir)

        # Precedence: command line, then launcher config, then whatever
        # batch size the network itself declares.
        batch_size = args.batch_size or launcher.get('batch')
        if not batch_size:
            with Network(str(launcher['model']), str(launcher['weights'])) as network:
                batch_size = network.ie_network.batch_size

        cpu_extension = None
        if 'cpu_extensions' in launcher:
            cpu_extension = DLSDKLauncher.get_cpu_extension(launcher['cpu_extensions'], args.cpu_extensions_mode)
            launcher['cpu_extensions'] = cpu_extension

        # Unless explicitly requested, FullyConnected layers are added to
        # the ignored layer types.
        if not args.calibrate_fully_connected:
            if args.ignore_layer_types is None:
                args.ignore_layer_types = []
            args.ignore_layer_types.append("FullyConnected")

        # NOTE(review): `model` is passed alongside `weights`; confirm that
        # CalibrationConfiguration accepts a `model` keyword.
        return CalibrationConfiguration(
            config=updated_config,
            precision=args.precision,
            model=str(model),
            weights=str(weights),
            tmp_directory=tmp_directory,
            output_model=output_model,
            output_weights=output_weights,
            cpu_extension=str(cpu_extension) if cpu_extension else None,
            gpu_extension=str(launcher['gpu_extensions']) if 'gpu_extensions' in launcher else None,
            device=launcher['device'],
            batch_size=batch_size,
            threshold=args.threshold,
            ignore_layer_types=args.ignore_layer_types,
            ignore_layer_types_path=args.ignore_layer_types_path,
            ignore_layer_names=args.ignore_layer_names,
            ignore_layer_names_path=args.ignore_layer_names_path,
            benchmark_iterations_count=args.benchmark_iterations_count,
            progress=(None if args.progress == 'None' else args.progress))