Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / utils / network_info.py
1 """
2 Copyright (C) 2018-2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import xmltodict
18 from typing import List
19
20 from .layer import Layer
21 from .edge import Edge
22 from .connection import Connection
23
24
25 # TODO: custom implementation:
26 # 1. get in/out layers
27 # 2. add_layer
class NetworkInfo:
    """Graph view of an IR model file.

    The XML document is parsed with ``xmltodict``. Layers are indexed by
    integer id and by name, edges are grouped by the id of the layer they
    leave and the id of the layer they enter, and every layer is initialised
    with the connections built from those edges.
    """

    def __init__(self, model_path: str):
        """Read the IR file at ``model_path`` and build the layer graph.

        Raises:
            ValueError: if the document has no ``net`` root element, or if an
                edge attached to a known layer references a layer id that
                was not declared.
        """
        with open(model_path, 'r') as model_file:
            model_content = model_file.read()

        model_xml = xmltodict.parse(model_content, attr_prefix='')
        if 'net' not in model_xml:
            raise ValueError("IR file '{}' format is not correct".format(model_path))

        self._model = model_xml['net']

        self._index_edges()
        self._index_layers()
        self._connect_layers()

    @staticmethod
    def _as_list(node) -> list:
        """Return ``node`` as a list of child elements.

        xmltodict produces ``None`` for an empty element, a single mapping
        when a tag occurs once and a list only when the tag is repeated, so
        the result cannot be iterated directly.
        """
        if node is None:
            return []
        if isinstance(node, list):
            return node
        return [node]

    def _children(self, section: str, tag: str) -> list:
        """Return the ``tag`` elements found under ``section`` of the model."""
        container = (self._model or {}).get(section) or {}
        return self._as_list(container.get(tag))

    def _index_edges(self) -> None:
        """Group ``Edge`` objects by source layer id and by target layer id."""
        self._edges_by_from_layer = dict()
        self._edges_by_to_layer = dict()
        for raw_edge in self._children('edges', 'edge'):
            from_layer = int(raw_edge['from-layer'])
            to_layer = int(raw_edge['to-layer'])

            edge = Edge(raw_edge)
            self._edges_by_from_layer.setdefault(from_layer, []).append(edge)
            self._edges_by_to_layer.setdefault(to_layer, []).append(edge)

    def _index_layers(self) -> None:
        """Create a ``Layer`` per layer element, keyed by id and by name."""
        self._layer_by_id = dict()
        self._layer_by_name = dict()
        for raw_layer in self._children('layers', 'layer'):
            layer = Layer(raw_layer)
            self._layer_by_id[int(raw_layer['id'])] = layer
            self._layer_by_name[layer.name] = layer

    def _connect_layers(self) -> None:
        """Build input/output connections for every layer and pass them to ``Layer.init``.

        Raises:
            ValueError: if an edge points at a layer id that is not indexed.
        """
        for layer_id, layer in self._layer_by_id.items():
            inputs = []
            for edge in self._edges_by_to_layer.get(layer_id, ()):
                if edge.from_layer not in self._layer_by_id:
                    raise ValueError("layer with id {} was not found".format(edge.from_layer))

                # The port belongs to this layer; the connected layer is the producer.
                inputs.append(Connection(
                    edge=edge,
                    port=layer.input_ports[edge.to_port],
                    layer=self._layer_by_id[edge.from_layer]))

            outputs = []
            for edge in self._edges_by_from_layer.get(layer_id, ()):
                if edge.to_layer not in self._layer_by_id:
                    raise ValueError("layer with id {} was not found".format(edge.to_layer))

                # The port belongs to this layer; the connected layer is the consumer.
                outputs.append(Connection(
                    edge=edge,
                    port=layer.output_ports[edge.from_port],
                    layer=self._layer_by_id[edge.to_layer]))

            layer.init(inputs, outputs)

    def get_layer_names(self, layer_types: List[str]) -> List[str]:
        """Return the names of all layers whose type is in ``layer_types``.

        An empty or ``None`` ``layer_types`` yields an empty list.
        """
        if not layer_types:
            return []
        return [layer.name for layer in self._layer_by_name.values() if layer.type in layer_types]

    @property
    def layers(self) -> dict:
        """Mapping of integer layer id to layer object."""
        return self._layer_by_id

    def get_layer(self, layer_name: str) -> "Layer":
        """Return the layer called ``layer_name``.

        Raises:
            KeyError: if no layer with that name exists.
        """
        return self._layer_by_name[layer_name]

    def explore_inputs(self, layer: "Layer", expected_input_types: List[str]) -> bool:
        """Check that every layer reachable through inputs has an expected type.

        Walks the inputs of ``layer`` recursively and returns ``False`` as
        soon as a layer whose type is not in ``expected_input_types`` is met.
        A layer without inputs yields ``True``.
        """
        return all(
            layer_input.layer.type in expected_input_types
            and self.explore_inputs(layer_input.layer, expected_input_types)
            for layer_input in layer.inputs
        )

    @property
    def inputs(self):
        """Mapping of layer id to layer for every layer of type ``'Input'``."""
        return {
            layer_id: layer
            for layer_id, layer in self.layers.items()
            if layer.type == 'Input'
        }