2 Copyright (C) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from pathlib import Path as std_path
23 import openvino.inference_engine as ie
24 from .utils.path import Path
def reload(model_path: str, statistics = None, quantization_levels: dict = None, batch_size: int = None):
    """Serialize a tuned copy of the IR into a temp dir, reload it, and return the new Network.

    :param model_path: path to the source IR .xml file
    :param statistics: optional per-layer min/max statistics (see set_statistics)
    :param quantization_levels: optional mapping of layer name -> quantization level
    :param batch_size: optional batch size to reshape the reloaded network to
    :return: a Network backed by the reloaded (already deserialized) IR
    """
    with Network(model_path) as network:
        # set_statistics iterates its argument, so skip it when no stats were given
        # (the default is None).
        if statistics:
            network.set_statistics(statistics)
        if quantization_levels:
            network.set_quantization_levels(quantization_levels)
        tmp_model_dir = tempfile.mkdtemp(".model")
        tmp_model_path = os.path.join(tmp_model_dir, ntpath.basename(model_path))
        network.serialize(tmp_model_path)

    network = Network(tmp_model_path)
    try:
        # Accessing network.ie_network forces the temp IR to be loaded (and
        # rebatched) before its files are removed below.
        Network.reshape(network.ie_network, batch_size)
    finally:
        # Always clean up the temp dir, even if reshape raises.
        shutil.rmtree(tmp_model_dir)
    return network
def serialize_tmp_model(model_path: str, statistics = None, quantization_levels: dict = None):
    """Serialize a tuned copy of the IR into a fresh temp dir.

    :param model_path: path to the source IR .xml file
    :param statistics: optional per-layer min/max statistics (see set_statistics)
    :param quantization_levels: optional mapping of layer name -> quantization level
    :return: path to the temporary .xml on success, None on failure
    """
    try:
        with Network(model_path) as network:
            # Guard both optional tuning inputs; their default is None.
            if statistics:
                network.set_statistics(statistics)
            if quantization_levels:
                network.set_quantization_levels(quantization_levels)
            tmp_model_dir = tempfile.mkdtemp(".model")
            tmp_model_path = os.path.join(tmp_model_dir, ntpath.basename(model_path))
            network.serialize(tmp_model_path)
            return tmp_model_path
    except Exception:
        # Best-effort: report and fall through to an implicit None so callers
        # can detect the failure (message text kept as-is).
        print('Could not serialize temporary IR')
def rm_tmp_location(file_path):
    """Remove the directory that contains *file_path*, including all its contents."""
    containing_dir = std_path(file_path).parent
    shutil.rmtree(str(containing_dir))
def __init__(self, model_path: str, weights_path: str=None):
    """Store the IR file locations; derive the weights path when not supplied.

    :param model_path: path to the IR topology (.xml) file; must not be None
    :param weights_path: optional path to the weights (.bin) file
    :raises ValueError: if model_path is None
    """
    if model_path is None:
        raise ValueError("model_path is None")
    self._model_path = model_path
    # Fall back to the conventional weights file derived from the model path.
    self._weights_path = weights_path or Path.get_weights(model_path)
    # Loaded lazily on first access of the ie_network property.
    self._ie_network = None
def __exit__(self, exc_type, exc_value, exc_traceback):
    """Drop the cached IENetwork when leaving the ``with`` block."""
    self._ie_network = None
def reshape(ie_network: ie.IENetwork, batch_size: int) -> ie.IENetwork:
    """Rebatch every input of *ie_network* to *batch_size* and return the network.

    No-op when batch_size is falsy or already equals the network batch size.

    :param ie_network: network whose inputs are reshaped in place
    :param batch_size: desired batch size
    :raises ValueError: for an input layout other than the supported ones
    :return: the (possibly reshaped) ie_network
    """
    if batch_size and batch_size != ie_network.batch_size:
        new_shapes = {}
        for input_layer_name, input_layer in ie_network.inputs.items():
            layout = input_layer.layout
            # NOTE(review): branch conditions reconstructed from the shape
            # expressions — 'C' keeps the batch-less vector shape, 'NC'
            # replaces the leading batch dim; confirm against the original.
            if layout == 'C':
                new_shape = (input_layer.shape[0],)
            elif layout == 'NC':
                new_shape = (batch_size, input_layer.shape[1])
            else:
                raise ValueError("not supported layout '{}'".format(layout))
            new_shapes[input_layer_name] = new_shape
        ie_network.reshape(new_shapes)
    return ie_network
def model_path(self) -> str:
    """Path to the IR topology (.xml) file this network was created from."""
    return self._model_path
def weights_path(self) -> str:
    """Path to the IR weights (.bin) file."""
    return self._weights_path
def ie_network(self) -> ie.IENetwork:
    """Lazily build and cache the IENetwork from the stored model/weights paths."""
    if self._ie_network:
        return self._ie_network
    self._ie_network = ie.IENetwork(self._model_path, self._weights_path)
    return self._ie_network
def set_quantization_levels(self, quantization_level: dict):
    """Write a 'quantization_level' param into each named layer of the network.

    :param quantization_level: mapping of layer name -> quantization level value
    """
    for name, level in quantization_level.items():
        # Read-modify-write: the params mapping is reassigned so the change
        # is actually pushed back into the layer.
        updated = self.ie_network.layers[name].params
        updated["quantization_level"] = level
        self.ie_network.layers[name].params = updated
def set_statistics(self, statistics: dict):
    """Attach per-layer min/max output statistics to the network.

    :param statistics: mapping of layer name -> object exposing ``min_outputs``
        and ``max_outputs`` sequences; a falsy value is a no-op
    """
    if not statistics:
        # reload() passes statistics straight through with a None default.
        return
    network_stats = {}
    for layer_name, node_statistic in statistics.items():
        network_stats[layer_name] = ie.LayerStats(min=tuple(node_statistic.min_outputs),
                                                  max=tuple(node_statistic.max_outputs))
    self.ie_network.stats.update(network_stats)
def serialize(self, model_path: str, weights_path: str=None):
    """Write the network to *model_path* (.xml) plus a weights (.bin) file.

    :param model_path: destination path of the IR topology file
    :param weights_path: optional destination of the weights file; derived
        from model_path when not supplied
    """
    target_weights = weights_path if weights_path else Path.get_weights(model_path)
    self.ie_network.serialize(model_path, target_weights)