Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / graph / port.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 from copy import deepcopy
17
18 import numpy as np
19 import networkx as nx
20
21 from collections import namedtuple
22
23 from mo.front.common.partial_infer.utils import int64_array
24 from mo.graph.connection import Connection
25 from mo.utils.error import Error
26
27
28 class Port:
29     def __init__(self, node, idx: int, type: str):
30         if type not in ['in', 'out']:
31             raise Error("Inappropriate port type: {}".format(type))
32
33         # We use self.__dict__ only to not to call __setattr__ method from __init__ function
34         self.__dict__['node'] = node
35         self.__dict__['idx'] = idx
36         self.__dict__['type'] = type
37         self.__dict__['data'] = namedtuple('Data', ['get_value', 'get_shape', 'get_attr', 'set_value', 'set_shape', 'set_attr', 'has_valid'])
38
39         self.data.get_shape = self._get_shape
40         self.data.set_shape = self._set_shape
41
42         self.data.get_value = self._get_value
43         self.data.set_value = self._set_value
44
45         self.data.get_attr = self._get_attr
46         self.data.set_attr = self._set_attr
47
48         self.data.has_valid = self._has_valid
49
50     def __eq__(self, other):
51         return (
52             self.__class__ == other.__class__ and
53             self.node.graph == other.node.graph and
54             self.node.id == other.node.id and
55             self.type == other.type and
56             self.idx == other.idx
57         )
58
59     def __deepcopy__(self, memo):
60         cls = self.__class__
61         result = cls.__new__(cls)
62         memo[id(self)] = result
63         for k, v in self.__dict__.items():
64             result.__dict__[k] = v if k in ['graph', 'node'] else deepcopy(v)
65         return result
66
67     def __setattr__(self, key, value):
68         edge = self.node.in_edge(self.idx) if self.type == 'in' else self.node.out_edge(self.idx)
69         edge[key] = value
70
71     def __getattr__(self, item):
72         edge = self.node.in_edge(self.idx) if self.type == 'in' else self.node.out_edge(self.idx)
73
74     def _create_data_if_necessary(self):
75         if self.node.graph.stage == 'front':
76             raise Error("_create_data_if_necessary method is not applicable for front Graph phase!")
77         if self.type == 'in':
78             raise Error("_create_data_if_necessary method is not applicable for 'in' Port type!")
79
80         if self.idx not in self.node.out_nodes():
81             from mo.ops.op import Op
82             Op.create_data_node(self.node.graph, self.node, out_port=self.idx)
83             self.node['need_shape_inference'] = True
84         return self.node.out_node(self.idx)
85
86     def _get_shape(self):
87         if self.node.graph.stage == 'front':
88             return None
89         else:
90             if self.type == 'in':
91                 return self.node.in_node(self.idx).shape
92             else:
93                 return self.node.out_node(self.idx).shape
94
95     def _set_shape(self, shape):
96         if self.node.graph.stage == 'front':
97             raise NotImplementedError("set_shape not implemented for front phase")
98         else:
99             if self.type == 'in':
100                 assert self.node.in_node(self.idx).value is None
101                 self.node.in_node(self.idx).shape = int64_array(shape)
102             else:
103                 assert self.node.out_node(self.idx).value is None
104                 self.node.out_node(self.idx).shape = int64_array(shape)
105
106     def _get_value(self):
107         if self.node.graph.stage == 'front':
108             return None
109         else:
110             if self.type == 'in':
111                 if self.idx in self.node.in_nodes() and self.node.in_node(self.idx).has_valid('value'):
112                     return self.node.in_node(self.idx).value
113             else:
114                 if self.idx in self.node.out_nodes() and self.node.out_node(self.idx).has_valid('value'):
115                     return self.node.out_node(self.idx).value
116         return None
117
118     def _set_value(self, value):
119         if self.node.graph.stage == 'front':
120             raise Error("set_value is not applicable for graph front phase")
121         else:
122             if self.type == 'in':
123                 self.node.in_node(self.idx).value = value
124                 self.node.in_node(self.idx).shape = int64_array(value.shape)
125             else:
126                 self.node.out_node(self.idx).value = value
127                 self.node.out_node(self.idx).shape = int64_array(value.shape)
128
129     def _get_attr(self, item: str):
130         if self.node.graph.stage == 'front':
131             return None
132         else:
133             if self.type == 'in':
134                 if self.idx in self.node.in_nodes() and self.node.in_node(self.idx).has_valid(item):
135                     return self.node.in_node(self.idx)[item]
136             else:
137                 if self.idx in self.node.out_nodes() and self.node.out_node(self.idx).has_valid(item):
138                     return self.node.out_node(self.idx)[item]
139         return None
140
141     def _set_attr(self, item, value):
142         raise NotImplementedError()
143
144     def get_in_edge_attrs(self, data=False):
145         assert self.type == 'in'
146         for u, v, d in list(self.node.graph.in_edges(self.node.id, data=True)):
147             if d['in'] == self.idx:
148                 edge_attrs = self.node.graph.get_edge_data(u, v)
149                 for key in edge_attrs:
150                     if edge_attrs[key]['in'] == self.idx:
151                         if data:
152                             return edge_attrs[key], u, v, key
153                         else:
154                             return edge_attrs[key]
155         if data:
156             return None, None, None, None
157         else:
158             return None
159
160     def _has_valid(self, item):
161         if self.node.graph.stage == 'front':
162             raise NotImplementedError
163         else:
164             if self.type == 'in':
165                 if self.idx in self.node.in_nodes() and self.node.in_node(self.idx).has_valid(item):
166                     return True
167             else:
168                 if self.idx in self.node.out_nodes() and self.node.out_node(self.idx).has_valid(item):
169                     return True
170         return False
171
172     def disconnected(self):
173         # This method returns False if port connected with some other port
174         # otherwise it returns True
175
176         if self.type == 'in':
177             return self.get_source() is None
178         else:
179             return len(self.get_destinations()) == 0
180
181     def get_source(self):
182         # This method returns Port object that is producer (source) port for out port.
183         # In case if out port has no source port return None
184
185         assert self.type != 'out', "Can't get source for output port at {} node".format(self.node.name)
186
187         from mo.graph.graph import Node
188         producer_ports = []
189
190         has_producer = False
191         if self.node.graph.stage == 'front':
192             for n, d in self.node.get_inputs():
193                 if d['in'] == self.idx:
194                     node = Node(self.node.graph, n)
195                     producer_ports.append(node.out_port(d['out']))
196                     has_producer = True
197             if not has_producer:
198                 return None
199         else:
200             if self.idx not in self.node.in_nodes():
201                 return None
202
203             in_data = self.node.in_node(self.idx)
204             for n, d in in_data.get_inputs():
205                 node = Node(self.node.graph, n)
206                 producer_ports.append(node.out_port(d['out']))
207
208         if len(producer_ports) != 1:
209             raise Error("Something happened with graph! data node has {} producers".format(len(producer_ports)))
210
211         return producer_ports[0]
212
213     def get_destination(self):
214         # This method returns Port that is consumer (destination) port for in port.
215         # In case if in port has no consumer return None
216
217         consumer_ports = self.get_destinations()
218         if not consumer_ports:
219             return None
220
221         if len(consumer_ports) > 1:
222             raise Error("The number of destinations for {} node at {} port is {}".format(self.node.name,
223                                                                                          self.idx,
224                                                                                          len(consumer_ports)))
225         return consumer_ports[0]
226
227     def get_destinations(self):
228         assert self.type != 'in', "Can't get destinations for input port at {} node".format(self.node.name)
229
230         from mo.graph.graph import Node
231         consumer_ports = []
232         if self.node.graph.stage == 'front':
233             producer_node = self.node
234         else:
235             # In case if node has no output data node in given port, we return None
236             if self.idx not in self.node.out_nodes():
237                 return []
238             producer_node = self.node.out_node(self.idx)
239
240         for n, d in producer_node.get_outputs():
241             node = Node(self.node.graph, n)
242             consumer_ports.append(node.in_port(d['in']))
243         return consumer_ports
244
245     def disconnect(self):
246         if self.type == 'out':
247             consumer_ports = self.get_destinations()
248             if self.node.graph.stage == 'front':
249                 for port in consumer_ports:
250                     self.node.graph.remove_edge(self.node.id, port.node.id)
251             else:
252                 for port in consumer_ports:
253                     self.node.graph.remove_edge(port.node.in_node(port.idx).id, port.node.id)
254         else:
255             source_port = self.get_source()
256             if source_port is None:
257                 return
258             for u, v, d in list(self.node.graph.in_edges(self.node.id, data=True)):
259                 if d['in'] == self.idx:
260                     for key in self.node.graph.get_edge_data(u, v):
261                         if self.node.graph.get_edge_data(u, v)[key]['in'] == self.idx:
262                             self.node.graph.remove_edge(u, v, key=key)
263                             return
264
265     def get_connection(self):
266         if self.type == 'in':
267             return Connection(self.node.graph, self.get_source(), [self])
268         else:
269             return Connection(self.node.graph, self, self.get_destinations())
270
271     def connect(self, port):
272         if self.type == 'in':
273             self.get_connection().set_source(port)
274         else:
275             self.get_connection().add_destination(port)