Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / network.py
1 """
2 Copyright (C) 2018-2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import os
18 import tempfile
19 import shutil
20 import ntpath
21
22 import openvino.inference_engine as ie
23 from .utils.path import Path
24
25
class Network:
    """Wrapper around ``openvino.inference_engine.IENetwork``.

    Stores the model and weights paths and loads the underlying
    ``IENetwork`` lazily on first access of :attr:`ie_network`. Instances can be
    used as context managers; leaving the ``with`` block calls :meth:`release`.
    """

    @staticmethod
    def reload(model_path: str, statistics=None, quantization_levels: dict = None,
               batch_size: int = None):
        """Return a new Network with statistics / quantization levels baked in.

        The model at ``model_path`` is loaded and optionally annotated with
        ``statistics`` (see :meth:`set_statistics`) and ``quantization_levels``
        (see :meth:`set_quantization_levels`). It is then serialized into a
        temporary directory and loaded back from there. Finally it is reshaped
        to ``batch_size`` (see :meth:`reshape`).

        NOTE: the temporary directory is removed before this method returns.
        The returned network's ``model_path`` / ``weights_path`` therefore refer
        to deleted files. Its ``ie_network`` is loaded here, before the
        deletion. However, calling :meth:`release` on it and accessing
        ``ie_network`` again would try to load from the deleted files.

        Args:
            model_path: path of the source model.
            statistics: optional mapping layer name -> object exposing
                ``min_outputs`` / ``max_outputs``; skipped when falsy.
            quantization_levels: optional mapping layer name -> quantization
                level; skipped when falsy.
            batch_size: optional target batch size; falsy means "keep as is".

        Returns:
            The reloaded :class:`Network`.

        Raises:
            ValueError: if ``model_path`` is None or an input layout is not
                supported by :meth:`reshape`.
        """
        tmp_model_dir = None
        try:
            with Network(model_path) as source:
                if statistics:
                    source.set_statistics(statistics)
                if quantization_levels:
                    source.set_quantization_levels(quantization_levels)

                tmp_model_dir = tempfile.mkdtemp(".model")
                # ntpath.basename splits on both '/' and '\\', so this works
                # for POSIX and Windows style paths alike.
                tmp_model_path = os.path.join(tmp_model_dir, ntpath.basename(model_path))
                source.serialize(tmp_model_path)

            network = Network(tmp_model_path)
            # Accessing ie_network loads the serialized model now, i.e. before
            # the temporary files are removed in the finally clause below.
            Network.reshape(network.ie_network, batch_size)
            return network
        finally:
            if tmp_model_dir:
                shutil.rmtree(tmp_model_dir)

    def __init__(self, model_path: str, weights_path: str = None):
        """Remember the model location; nothing is loaded yet.

        Args:
            model_path: path of the model file.
            weights_path: path of the weights file; when falsy it is derived
                from ``model_path`` via ``Path.get_weights``.

        Raises:
            ValueError: if ``model_path`` is None.
        """
        if model_path is None:
            raise ValueError("model_path is None")

        self._model_path = model_path
        self._weights_path = weights_path or Path.get_weights(model_path)
        # Loaded on demand by the ie_network property.
        self._ie_network = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Always release; returning None lets any exception propagate.
        self.release()

    def release(self):
        """Drop the reference to the loaded IENetwork, if any.

        A later access to :attr:`ie_network` loads it again from
        ``model_path`` / ``weights_path``.
        """
        self._ie_network = None

    @staticmethod
    def reshape(ie_network: ie.IENetwork, batch_size: int) -> ie.IENetwork:
        """Reshape all inputs of ``ie_network`` to ``batch_size`` (in place).

        Nothing is done when ``batch_size`` is falsy (e.g. ``None`` or ``0``)
        or already equals ``ie_network.batch_size``. Inputs with layout
        ``'C'`` have no batch axis and keep their shape. For inputs whose
        layout starts with ``'N'`` (``'NC'``, ``'NCHW'``, ``'NHWC'``, ...),
        the first dimension is replaced by ``batch_size``.

        Returns:
            The same ``ie_network`` object.

        Raises:
            ValueError: if an input has any other layout.
        """
        if not batch_size or batch_size == ie_network.batch_size:
            return ie_network

        new_shapes = {}
        for input_layer_name, input_layer in ie_network.inputs.items():
            layout = input_layer.layout
            shape = input_layer.shape
            if layout == 'C':
                new_shape = (shape[0],)
            elif layout.startswith('N'):
                # 'N' is the batch axis and always comes first in these layouts.
                new_shape = (batch_size,) + tuple(shape[1:])
            else:
                raise ValueError("not supported layout '{}'".format(layout))
            new_shapes[input_layer_name] = new_shape
        ie_network.reshape(new_shapes)
        return ie_network

    @property
    def model_path(self) -> str:
        """Path of the model file this network was created from."""
        return self._model_path

    @property
    def weights_path(self) -> str:
        """Path of the weights file (explicit or derived from model_path)."""
        return self._weights_path

    @property
    def ie_network(self) -> ie.IENetwork:
        """The underlying IENetwork, loaded from disk on first access."""
        if self._ie_network is None:
            self._ie_network = ie.IENetwork(self._model_path, self._weights_path)
        return self._ie_network

    def set_quantization_levels(self, quantization_level: dict):
        """Write ``params["quantization_level"]`` on the given layers.

        Args:
            quantization_level: mapping layer name -> value stored under the
                ``"quantization_level"`` key of that layer's params.
        """
        # Fetch the layer mapping once instead of twice per entry.
        layers = self.ie_network.layers
        for layer_name, value in quantization_level.items():
            layer = layers[layer_name]
            params = layer.params
            params["quantization_level"] = value
            # Assign back through the setter: mutating the returned dict alone
            # may not reach the underlying layer.
            layer.params = params

    def set_statistics(self, statistics: dict):
        """Attach per-layer min/max output statistics to the network.

        Args:
            statistics: mapping layer name -> object exposing iterable
                ``min_outputs`` and ``max_outputs``; they are converted to
                tuples for ``ie.LayerStats``.
        """
        network_stats = {
            layer_name: ie.LayerStats(min=tuple(node_statistic.min_outputs),
                                      max=tuple(node_statistic.max_outputs))
            for layer_name, node_statistic in statistics.items()
        }
        self.ie_network.stats.update(network_stats)

    def serialize(self, model_path: str, weights_path: str = None):
        """Write the network to ``model_path`` and its weights to ``weights_path``.

        When ``weights_path`` is falsy it is derived from ``model_path`` via
        ``Path.get_weights``.
        """
        self.ie_network.serialize(model_path, weights_path or Path.get_weights(model_path))