1 """
2 Copyright (C) 2018-2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """

import os
import tempfile
import shutil
import ntpath
from pathlib import Path as std_path

import openvino.inference_engine as ie
from .utils.path import Path


class Network:
    @staticmethod
    def reload(model_path: str, statistics=None, quantization_levels: dict = None, batch_size: int = None):
        tmp_model_dir = None
        try:
            with Network(model_path) as network:
                if statistics:
                    network.set_statistics(statistics)
                if quantization_levels:
                    network.set_quantization_levels(quantization_levels)

                tmp_model_dir = tempfile.mkdtemp(".model")
                tmp_model_path = os.path.join(tmp_model_dir, ntpath.basename(model_path))
                network.serialize(tmp_model_path)

            network = Network(tmp_model_path)
            Network.reshape(network.ie_network, batch_size)
            return network
        finally:
            if tmp_model_dir:
                shutil.rmtree(tmp_model_dir)

    @staticmethod
    def serialize_tmp_model(model_path: str, statistics=None, quantization_levels: dict = None):
        tmp_model_dir = None
        try:
            with Network(model_path) as network:
                if statistics:
                    network.set_statistics(statistics)
                if quantization_levels:
                    network.set_quantization_levels(quantization_levels)

                tmp_model_dir = tempfile.mkdtemp(".model")
                tmp_model_path = os.path.join(tmp_model_dir, ntpath.basename(model_path))
                network.serialize(tmp_model_path)
            return tmp_model_path
        except Exception:
            if tmp_model_dir:
                shutil.rmtree(tmp_model_dir, ignore_errors=True)
            print('Could not serialize temporary IR')
            raise

    @staticmethod
    def rm_tmp_location(file_path):
        if file_path:
            pdir = std_path(file_path).parent
            shutil.rmtree(str(pdir))

    def __init__(self, model_path: str, weights_path: str = None):
        if model_path is None:
            raise ValueError("model_path is None")

        self._model_path = model_path
        self._weights_path = weights_path if weights_path else Path.get_weights(model_path)
        self._ie_network = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.release()

    def release(self):
        if self._ie_network:
            del self._ie_network
            self._ie_network = None

    @staticmethod
    def reshape(ie_network: ie.IENetwork, batch_size: int) -> ie.IENetwork:
        if batch_size and batch_size != ie_network.batch_size:
            new_shapes = {}
            for input_layer_name, input_layer in ie_network.inputs.items():
                layout = input_layer.layout
                if layout == 'C':
                    new_shape = (input_layer.shape[0],)
                elif layout == 'NC':
                    new_shape = (batch_size, input_layer.shape[1])
                else:
                    raise ValueError("unsupported layout '{}'".format(layout))
                new_shapes[input_layer_name] = new_shape
            ie_network.reshape(new_shapes)
        return ie_network

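    # Illustrative only (hypothetical input name 'data'): for a model with a single
    # input of layout 'NC' and shape [1, 1024], Network.reshape(net, 32) requests
    # net.reshape({'data': (32, 1024)}); inputs with any other layout
    # (e.g. 'NCHW') are rejected by the method above with a ValueError.
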
    @property
    def model_path(self) -> str:
        return self._model_path

    @property
    def weights_path(self) -> str:
        return self._weights_path

    @property
    def ie_network(self) -> ie.IENetwork:
        if not self._ie_network:
            self._ie_network = ie.IENetwork(self._model_path, self._weights_path)
        return self._ie_network

    def set_quantization_levels(self, quantization_level: dict):
        for layer_name, value in quantization_level.items():
            params = self.ie_network.layers[layer_name].params
            params["quantization_level"] = value
            self.ie_network.layers[layer_name].params = params

    def set_statistics(self, statistics: dict):
        network_stats = {}
        for layer_name, node_statistic in statistics.items():
            network_stats[layer_name] = ie.LayerStats(min=tuple(node_statistic.min_outputs),
                                                      max=tuple(node_statistic.max_outputs))
        self.ie_network.stats.update(network_stats)

    def serialize(self, model_path: str, weights_path: str = None):
        self.ie_network.serialize(model_path, weights_path if weights_path else Path.get_weights(model_path))
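

# Minimal usage sketch (hypothetical, not part of the original module): how this class
# is typically driven from a calibration-style flow. The IR path and layer name below
# are placeholders, the API shown is the pre-2020 openvino.inference_engine Python API
# this file targets, and Network.reload only reshapes models whose inputs use the
# 'C' or 'NC' layouts (see reshape above).
if __name__ == '__main__':
    MODEL_XML = '/path/to/model.xml'  # hypothetical IR path

    # Reload the IR with a different batch size ('C'/'NC'-layout inputs only).
    reloaded = Network.reload(MODEL_XML, batch_size=8)
    print(reloaded.ie_network.batch_size)
    reloaded.release()

    # Serialize a temporary annotated copy, inspect it, then clean up its directory.
    tmp_xml = Network.serialize_tmp_model(MODEL_XML, quantization_levels={'fc1': 'I8'})
    try:
        with Network(tmp_xml) as tmp_network:
            print(list(tmp_network.ie_network.inputs.keys()))
    finally:
        Network.rm_tmp_location(tmp_xml)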