Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / calibration / calibration_configuration.py
1 import shutil
2 from ..utils.network_info import NetworkInfo
3
4
class CalibrationConfiguration:
    """
    Container for the settings of a calibration run.

    Every constructor argument is stored as given (only ``precision`` is
    upper-cased) and exposed through a read-only property of the same name.

    The instance is also a context manager: leaving the ``with`` block calls
    :meth:`release`, which deletes the temporary directory.
    """
    def __init__(
        self,
        config: str,
        precision: str,
        model: str,
        weights: str,
        tmp_directory: str,
        output_model: str,
        output_weights: str,
        cpu_extension: str,
        gpu_extension: str,
        device: str,
        batch_size: int,
        threshold: float,
        ignore_layer_types: list,
        ignore_layer_types_path: str,
        ignore_layer_names: list,
        ignore_layer_names_path: str,
        benchmark_iterations_count: int,
        progress: str):
        """
        Store the calibration settings.

        All values are kept by reference without validation; ``precision``
        must be a string because it is normalised with ``str.upper()``.
        """
        self._config = config
        self._precision = precision.upper()
        self._model = model
        self._weights = weights
        self._tmp_directory = tmp_directory
        self._output_model = output_model
        self._output_weights = output_weights
        self._cpu_extension = cpu_extension
        self._gpu_extension = gpu_extension
        self._device = device
        self._batch_size = batch_size
        self._threshold = threshold
        self._ignore_layer_types = ignore_layer_types
        self._ignore_layer_types_path = ignore_layer_types_path
        self._ignore_layer_names = ignore_layer_names
        self._ignore_layer_names_path = ignore_layer_names_path
        self._benchmark_iterations_count = benchmark_iterations_count
        self._progress = progress

    def __enter__(self):
        """Return the configuration itself for use inside a ``with`` block."""
        return self

    def __exit__(self, type, value, traceback):
        """Release the temporary directory; exceptions are not suppressed."""
        self.release()

    def release(self):
        """
        Delete the temporary directory tree, if one is still registered.

        Safe to call repeatedly: after the first successful call
        ``tmp_directory`` is ``None`` and further calls do nothing. A
        directory that no longer exists on disk is treated as already
        released instead of raising, so that cleanup in ``__exit__`` cannot
        hide an exception raised inside the ``with`` body for that reason.
        Other errors from ``shutil.rmtree`` still propagate, and in that
        case ``tmp_directory`` is left set so the call can be retried.
        """
        if not self._tmp_directory:
            return

        try:
            shutil.rmtree(self._tmp_directory)
        except FileNotFoundError:
            # Somebody else removed it first; there is nothing left to do.
            pass
        self._tmp_directory = None

    @property
    def config(self) -> list:
        # NOTE(review): __init__ annotates ``config`` as ``str`` while this
        # property says ``list``; the real type cannot be determined from
        # this file - confirm against the callers and align both.
        return self._config

    @property
    def precision(self) -> str:
        """Precision string, upper-cased at construction time."""
        return self._precision

    @property
    def model(self) -> str:
        return self._model

    @property
    def weights(self) -> str:
        return self._weights

    @property
    def tmp_directory(self) -> str:
        """Temporary directory path, or ``None`` once it has been released."""
        return self._tmp_directory

    @property
    def output_model(self) -> str:
        return self._output_model

    @property
    def output_weights(self) -> str:
        return self._output_weights

    @property
    def cpu_extension(self) -> str:
        return self._cpu_extension

    @property
    def gpu_extension(self) -> str:
        return self._gpu_extension

    @property
    def device(self) -> str:
        return self._device

    @property
    def batch_size(self) -> int:
        return self._batch_size

    @property
    def threshold(self) -> float:
        return self._threshold

    @property
    def ignore_layer_types(self) -> list:
        return self._ignore_layer_types

    @property
    def ignore_layer_types_path(self) -> str:
        return self._ignore_layer_types_path

    @property
    def ignore_layer_names(self) -> list:
        return self._ignore_layer_names

    @property
    def ignore_layer_names_path(self) -> str:
        return self._ignore_layer_names_path

    @property
    def benchmark_iterations_count(self) -> int:
        return self._benchmark_iterations_count

    @property
    def progress(self) -> str:
        return self._progress
131
132
class CalibrationConfigurationHelper:
    """Helpers that derive values from a ``CalibrationConfiguration``."""

    @staticmethod
    def _read_stripped_lines(path: str) -> list:
        """Return all lines of the text file at ``path``, whitespace-stripped."""
        # The context manager guarantees the handle is closed even if
        # reading fails part-way through.
        with open(path, 'r') as lines_file:
            return [line.strip() for line in lines_file]

    @staticmethod
    def read_ignore_layer_names(configuration: CalibrationConfiguration):
        """
        Collect the names of the layers that have to be ignored.

        The layer types to ignore are ``configuration.ignore_layer_types``
        plus, when ``configuration.ignore_layer_types_path`` is set, one type
        per line of that file. These types are passed to
        ``NetworkInfo(configuration.model).get_layer_names`` and the result
        is extended with one name per line of
        ``configuration.ignore_layer_names_path`` when that path is set.

        The configuration object is not modified: the type list is copied
        before the entries from the file are appended, so calling this
        method several times yields the same result.

        :param configuration: calibration settings to read from
        :return: layer names reported by ``NetworkInfo`` followed by the
            names read from the names file (if any)
        """
        ignore_layer_types = configuration.ignore_layer_types

        if configuration.ignore_layer_types_path:
            # Work on a copy: extending the list owned by the configuration
            # would accumulate duplicates on every call. ``or []`` also
            # covers a configuration created without any explicit types.
            ignore_layer_types = list(ignore_layer_types or [])
            ignore_layer_types.extend(
                CalibrationConfigurationHelper._read_stripped_lines(
                    configuration.ignore_layer_types_path))

        ignore_layer_names = NetworkInfo(configuration.model).get_layer_names(layer_types=ignore_layer_types)

        # NOTE(review): configuration.ignore_layer_names is not consulted
        # here; only the names file is. Confirm whether the caller merges
        # the explicit names itself or whether they should be added too.
        if configuration.ignore_layer_names_path:
            # Copy before extending so the list handed back by NetworkInfo
            # is left untouched, whatever its ownership is.
            ignore_layer_names = list(ignore_layer_names)
            ignore_layer_names.extend(
                CalibrationConfigurationHelper._read_stripped_lines(
                    configuration.ignore_layer_names_path))

        return ignore_layer_names