2 Copyright (C) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
22 import openvino.inference_engine as ie
23 from .utils.path import Path
def reload(model_path: str, statistics = None, quantization_levels: dict = None, batch_size: int = None):
    """Rebuild a network with calibration data written into its IR.

    The model at ``model_path`` is loaded. Optional layer statistics and
    per-layer quantization levels are applied to it. It is then serialized
    into a fresh temporary directory and loaded back from there. Finally it
    is reshaped to ``batch_size``.

    :param model_path: path to the source model topology file.
    :param statistics: optional mapping of layer name -> object exposing
        ``min_outputs``/``max_outputs``; skipped when None or empty.
    :param quantization_levels: optional mapping of layer name ->
        quantization level; skipped when None or empty.
    :param batch_size: requested batch size; a falsy value keeps the
        network's current batch size.
    :return: the reloaded ``Network``. Its ``model_path``/``weights_path``
        point into the temporary directory, which has already been deleted
        when this function returns.
    :raises ValueError: propagated from ``Network.reshape`` for an input
        layout it does not support.
    """
    tmp_model_dir = None
    try:
        with Network(model_path) as network:
            # Both arguments default to None; only apply what was supplied.
            if statistics:
                network.set_statistics(statistics)
            if quantization_levels:
                network.set_quantization_levels(quantization_levels)

            tmp_model_dir = tempfile.mkdtemp(".model")
            tmp_model_path = os.path.join(tmp_model_dir, ntpath.basename(model_path))
            network.serialize(tmp_model_path)

        reloaded = Network(tmp_model_path)
        # Accessing ``ie_network`` makes the lazy property read the IR now,
        # before the temporary files are removed in ``finally``.
        Network.reshape(reloaded.ie_network, batch_size)
        return reloaded
    finally:
        # Always clean up the temporary IR, even when a step above raised.
        if tmp_model_dir is not None:
            shutil.rmtree(tmp_model_dir)
def __init__(self, model_path: str, weights_path: str=None):
    """Record the IR file locations; the IENetwork itself is not read here.

    :param model_path: path to the model topology file; must not be None.
    :param weights_path: path to the weights file. When falsy, it is derived
        from ``model_path`` via ``Path.get_weights``.
    :raises ValueError: if ``model_path`` is None.
    """
    if model_path is None:
        raise ValueError("model_path is None")

    self._model_path = model_path
    if not weights_path:
        weights_path = Path.get_weights(model_path)
    self._weights_path = weights_path
    # Cache slot filled on first access of the ``ie_network`` property.
    self._ie_network = None
def __exit__(self, type, value, traceback):
    """Leave a ``with`` block by dropping the cached IENetwork reference.

    The exception arguments are ignored. The return value is falsy, so an
    exception raised inside the block is not suppressed.
    """
    self._ie_network = None
    return None
def reshape(ie_network: ie.IENetwork, batch_size: int) -> ie.IENetwork:
    """Reshape every input of ``ie_network`` to the requested batch size.

    Nothing is changed when ``batch_size`` is falsy or already equals
    ``ie_network.batch_size``.

    :param ie_network: network whose inputs are reshaped in place.
    :param batch_size: requested batch size, or None/0 to leave it as is.
    :return: the same ``ie_network`` object, for chaining.
    :raises ValueError: if an input has a layout other than the two handled
        below.
    """
    if not batch_size or batch_size == ie_network.batch_size:
        return ie_network

    new_shapes = {}
    for input_layer_name, input_layer in ie_network.inputs.items():
        layout = input_layer.layout
        # NOTE(review): the layout names below are reconstructed from the
        # shapes built in each branch (1-D input kept as is, 2-D input gets
        # the new batch as its first dimension). Confirm these are the layout
        # strings IENetwork reports for such inputs.
        if layout == "C":
            new_shape = (input_layer.shape[0],)
        elif layout == "NC":
            new_shape = (batch_size, input_layer.shape[1])
        else:
            raise ValueError("not supported layout '{}'".format(layout))
        new_shapes[input_layer_name] = new_shape
    ie_network.reshape(new_shapes)
    return ie_network
def model_path(self) -> str:
    """Path of the model topology file given at construction time."""
    path = self._model_path
    return path
def weights_path(self) -> str:
    """Path of the weights file, either given explicitly or derived at init."""
    path = self._weights_path
    return path
def ie_network(self) -> ie.IENetwork:
    """Return the wrapped IENetwork, reading it from disk on first access.

    The loaded object is cached in ``self._ie_network``. Later accesses
    reuse it, so changes made through it (stats, params, reshape) persist.
    """
    # None is the "not loaded" sentinel. Test it explicitly instead of relying
    # on how an IENetwork object behaves in a boolean context.
    if self._ie_network is None:
        self._ie_network = ie.IENetwork(self._model_path, self._weights_path)
    return self._ie_network
def set_quantization_levels(self, quantization_level: dict):
    """Store a ``quantization_level`` entry in the params of each listed layer.

    :param quantization_level: mapping of layer name -> quantization level.
    """
    for name, level in quantization_level.items():
        # The params mapping is read, updated, then assigned back as a whole.
        # Presumably the getter hands out a copy, so in-place edits alone
        # would not stick (verify against the IE API).
        layer_params = self.ie_network.layers[name].params
        layer_params["quantization_level"] = level
        self.ie_network.layers[name].params = layer_params
def set_statistics(self, statistics: dict):
    """Merge per-layer min/max output statistics into the network's stats.

    :param statistics: mapping of layer name -> object exposing iterable
        ``min_outputs`` and ``max_outputs``; each is converted to a tuple.
    """
    # Build the whole mapping first, then apply it in a single update call.
    network_stats = {
        layer_name: ie.LayerStats(min=tuple(node_statistic.min_outputs),
                                  max=tuple(node_statistic.max_outputs))
        for layer_name, node_statistic in statistics.items()
    }
    self.ie_network.stats.update(network_stats)
def serialize(self, model_path: str, weights_path: str=None):
    """Write the network IR to ``model_path`` and its weights to ``weights_path``.

    :param model_path: destination for the model topology file.
    :param weights_path: destination for the weights. When falsy, it is
        derived from ``model_path`` via ``Path.get_weights``.
    """
    # Resolve the bound method first (this may trigger the lazy network load),
    # then work out the weights destination.
    write_ir = self.ie_network.serialize
    if not weights_path:
        weights_path = Path.get_weights(model_path)
    write_ir(model_path, weights_path)