# Copyright (C) 2018-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18 from typing import List
20 from .layer import Layer
21 from .edge import Edge
22 from .connection import Connection
# TODO: custom implementation:
#   1. get input/output layers
def __init__(self, model_path: str):
    """Load an IR XML file and build the layer graph it describes.

    The file is parsed with ``xmltodict`` (attributes become plain keys, no
    prefix). The ``<edge>`` records are indexed by source and destination
    layer id, one ``Layer`` is created per ``<layer>`` record (indexed both by
    integer id and by name), and finally every layer is wired to its
    neighbours through ``Connection`` objects handed to ``Layer.init``.

    :param model_path: path to the IR ``.xml`` file.
    :raises ValueError: if the document has no top-level ``net`` element, or
        if an edge refers to a layer id that is not declared in ``<layers>``.
    """
    with open(model_path, 'r') as model_file:
        model_content = model_file.read()

    # NOTE(review): no `import xmltodict` is visible in this file; it must be
    # imported at module level for this call to resolve.
    model_xml = xmltodict.parse(model_content, attr_prefix='')
    if 'net' not in model_xml:
        raise ValueError("IR file '{}' format is not correct".format(model_path))

    self._model = model_xml['net']
    self._index_edges()
    self._index_layers()
    self._link_layers()

def _xml_children(self, section: str, tag: str) -> list:
    """Return the ``<tag>`` records under ``self._model[section]`` as a list.

    xmltodict yields a single dict (not a one-element list) when an element
    occurs exactly once, and ``None`` for an empty element such as
    ``<edges/>``; both are normalized here so callers can always iterate
    over records.
    """
    container = self._model.get(section) or {}
    children = container.get(tag)
    if children is None:
        return []
    if isinstance(children, list):
        return children
    return [children]

def _index_edges(self):
    """Group the IR edges by ``from-layer`` id and by ``to-layer`` id."""
    self._edges_by_from_layer = dict()
    self._edges_by_to_layer = dict()
    for ordered_edge in self._xml_children('edges', 'edge'):
        from_layer = int(ordered_edge['from-layer'])
        to_layer = int(ordered_edge['to-layer'])

        edge = Edge(ordered_edge)
        self._edges_by_from_layer.setdefault(from_layer, []).append(edge)
        self._edges_by_to_layer.setdefault(to_layer, []).append(edge)

def _index_layers(self):
    """Create one ``Layer`` per IR layer record, indexed by integer id and by name."""
    self._layer_by_id = dict()
    self._layer_by_name = dict()
    for ordered_layer in self._xml_children('layers', 'layer'):
        layer = Layer(ordered_layer)
        self._layer_by_id[int(ordered_layer['id'])] = layer
        self._layer_by_name[layer.name] = layer

def _link_layers(self):
    """Build input/output ``Connection`` lists for every layer and pass them to ``Layer.init``.

    :raises ValueError: if an edge refers to a layer id that was not declared.
    """
    for layer_id, layer in self._layer_by_id.items():
        # Incoming edges: the port is on this (consuming) layer, the linked
        # layer is the producer.
        inputs = []
        for edge in self._edges_by_to_layer.get(layer_id, []):
            if edge.from_layer not in self._layer_by_id:
                raise ValueError("layer with id {} was not found".format(edge.from_layer))
            from_layer = self._layer_by_id[edge.from_layer]
            inputs.append(Connection(edge=edge, port=layer.input_ports[edge.to_port], layer=from_layer))

        # Outgoing edges: the port is on this (producing) layer, the linked
        # layer is the consumer.
        outputs = []
        for edge in self._edges_by_from_layer.get(layer_id, []):
            if edge.to_layer not in self._layer_by_id:
                raise ValueError("layer with id {} was not found".format(edge.to_layer))
            to_layer = self._layer_by_id[edge.to_layer]
            outputs.append(Connection(edge=edge, port=layer.output_ports[edge.from_port], layer=to_layer))

        layer.init(inputs, outputs)
def get_layer_names_by_types(self, layer_types: List[str]) -> List[str]:
    """Return the names of all layers whose ``type`` is one of ``layer_types``.

    Names are returned in ``self._layer_by_name`` iteration order; an empty
    list is returned when no layer matches.

    :param layer_types: layer type names to select, e.g. ``['Convolution']``.
    :return: names of the matching layers.
    """
    return [
        layer.name
        for layer in self._layer_by_name.values()
        if layer.type in layer_types
    ]
@property
def layers(self) -> dict:
    """Mapping of integer layer id to ``Layer`` for every layer in the model.

    Exposed as a read-only property so it can be used as
    ``self.layers.items()``. The returned dict is the internal index itself,
    not a copy.
    """
    return self._layer_by_id
def get_layer(self, layer_name: str) -> Layer:
    """Return the layer registered under ``layer_name``.

    :param layer_name: the layer's ``name`` as read from the IR.
    :raises KeyError: if no layer with that name was loaded.
    """
    layer_index = self._layer_by_name
    return layer_index[layer_name]
def explore_inputs(self, layer: Layer, expected_input_types: List[str]) -> bool:
    """Check that every layer upstream of ``layer`` has an expected type.

    Walks the producer graph through ``Layer.inputs`` (each item exposes the
    producing layer as ``.layer``). Returns ``False`` as soon as a producer
    whose ``type`` is not in ``expected_input_types`` is found, ``True`` if
    all transitive producers match. ``layer`` itself is not checked.

    :param layer: the layer whose ancestry is inspected.
    :param expected_input_types: allowed ``type`` values for every ancestor.
    :return: ``True`` when all ancestors have an allowed type.
    """
    # Iterative walk with a visited set: shared producers are checked only
    # once and deep graphs cannot exhaust the recursion limit. Keyed by id()
    # so Layer objects need not be hashable.
    pending = [connection.layer for connection in layer.inputs]
    visited = set()
    while pending:
        producer = pending.pop()
        if id(producer) in visited:
            continue
        visited.add(id(producer))
        if producer.type not in expected_input_types:
            return False
        pending.extend(connection.layer for connection in producer.inputs)
    return True
120 for id, layer in self.layers.items():
121 if layer.type == 'Input':