"""
 Copyright (c) 2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
from collections import namedtuple
from copy import deepcopy

from mo.front.common.partial_infer.utils import int64_array
from mo.graph.connection import Connection
from mo.utils.error import Error
29 def __init__(self, node, idx: int, type: str):
30 if type not in ['in', 'out']:
31 raise Error("Inappropriate port type: {}".format(type))
33 # We use self.__dict__ only to not to call __setattr__ method from __init__ function
34 self.__dict__['node'] = node
35 self.__dict__['idx'] = idx
36 self.__dict__['type'] = type
37 self.__dict__['data'] = namedtuple('Data', ['get_value', 'get_shape', 'get_attr', 'set_value', 'set_shape', 'set_attr', 'has_valid'])
39 self.data.get_shape = self._get_shape
40 self.data.set_shape = self._set_shape
42 self.data.get_value = self._get_value
43 self.data.set_value = self._set_value
45 self.data.get_attr = self._get_attr
46 self.data.set_attr = self._set_attr
48 self.data.has_valid = self._has_valid
50 def __eq__(self, other):
52 self.__class__ == other.__class__ and
53 self.node.graph == other.node.graph and
54 self.node.id == other.node.id and
55 self.type == other.type and
59 def __deepcopy__(self, memo):
61 result = cls.__new__(cls)
62 memo[id(self)] = result
63 for k, v in self.__dict__.items():
64 result.__dict__[k] = v if k in ['graph', 'node'] else deepcopy(v)
67 def __setattr__(self, key, value):
68 edge = self.node.in_edge(self.idx) if self.type == 'in' else self.node.out_edge(self.idx)
71 def __getattr__(self, item):
72 edge = self.node.in_edge(self.idx) if self.type == 'in' else self.node.out_edge(self.idx)
74 def _create_data_if_necessary(self):
75 if self.node.graph.stage == 'front':
76 raise Error("_create_data_if_necessary method is not applicable for front Graph phase!")
78 raise Error("_create_data_if_necessary method is not applicable for 'in' Port type!")
80 if self.idx not in self.node.out_nodes():
81 from mo.ops.op import Op
82 Op.create_data_node(self.node.graph, self.node, out_port=self.idx)
83 self.node['need_shape_inference'] = True
84 return self.node.out_node(self.idx)
87 if self.node.graph.stage == 'front':
91 return self.node.in_node(self.idx).shape
93 return self.node.out_node(self.idx).shape
95 def _set_shape(self, shape):
96 if self.node.graph.stage == 'front':
97 raise NotImplementedError("set_shape not implemented for front phase")
100 assert self.node.in_node(self.idx).value is None
101 self.node.in_node(self.idx).shape = int64_array(shape)
103 assert self.node.out_node(self.idx).value is None
104 self.node.out_node(self.idx).shape = int64_array(shape)
106 def _get_value(self):
107 if self.node.graph.stage == 'front':
110 if self.type == 'in':
111 if self.idx in self.node.in_nodes() and self.node.in_node(self.idx).has_valid('value'):
112 return self.node.in_node(self.idx).value
114 if self.idx in self.node.out_nodes() and self.node.out_node(self.idx).has_valid('value'):
115 return self.node.out_node(self.idx).value
def _set_value(self, value):
    """Store ``value`` on the data node attached to this port and update its shape to match.

    :param value: new value; ``value.shape`` is read, so it must expose a ``shape`` attribute
    :raises Error: if the graph is in the 'front' stage
    """
    if self.node.graph.stage == 'front':
        raise Error("set_value is not applicable for graph front phase")

    if self.type == 'in':
        self.node.in_node(self.idx).value = value
        self.node.in_node(self.idx).shape = int64_array(value.shape)
    else:
        self.node.out_node(self.idx).value = value
        self.node.out_node(self.idx).shape = int64_array(value.shape)
129 def _get_attr(self, item: str):
130 if self.node.graph.stage == 'front':
133 if self.type == 'in':
134 if self.idx in self.node.in_nodes() and self.node.in_node(self.idx).has_valid(item):
135 return self.node.in_node(self.idx)[item]
137 if self.idx in self.node.out_nodes() and self.node.out_node(self.idx).has_valid(item):
138 return self.node.out_node(self.idx)[item]
141 def _set_attr(self, item, value):
142 raise NotImplementedError()
def get_in_edge_attrs(self, data=False):
    """Find the incoming graph edge whose ``'in'`` attribute equals this port's index.

    :param data: when true, return ``(attrs, u, v, key)``; otherwise return only ``attrs``
    :return: the edge attributes (optionally with the edge endpoints and key); if no edge
             matches, ``(None, None, None, None)`` for ``data=True`` and None otherwise

    NOTE(review): the ``if data:`` / ``else:`` lines and the final ``return None`` are missing
    from the visible text; they are reconstructed from the two differently shaped return
    statements and the otherwise unused ``data`` parameter.
    """
    assert self.type == 'in'
    for u, v, d in list(self.node.graph.in_edges(self.node.id, data=True)):
        if d['in'] == self.idx:
            # There may be several parallel edges between u and v: pick the one
            # that lands on this port.
            edge_attrs = self.node.graph.get_edge_data(u, v)
            for key in edge_attrs:
                if edge_attrs[key]['in'] == self.idx:
                    if data:
                        return edge_attrs[key], u, v, key
                    return edge_attrs[key]
    if data:
        return None, None, None, None
    return None
160 def _has_valid(self, item):
161 if self.node.graph.stage == 'front':
162 raise NotImplementedError
164 if self.type == 'in':
165 if self.idx in self.node.in_nodes() and self.node.in_node(self.idx).has_valid(item):
168 if self.idx in self.node.out_nodes() and self.node.out_node(self.idx).has_valid(item):
def disconnected(self):
    """Return False if the port is connected to some other port, True otherwise.

    An input port is disconnected when ``get_source()`` is None; an output port is
    disconnected when ``get_destinations()`` is empty.
    """
    if self.type == 'in':
        return self.get_source() is None
    return len(self.get_destinations()) == 0
def get_source(self):
    """Return the producer (source) port that feeds this input port.

    :return: the producer's output port, or None if this input port has no source
    :raises Error: if more than one producer is found (or, outside the 'front' stage,
                   the input data node has no producer)
    """
    assert self.type != 'out', "Can't get source for output port at {} node".format(self.node.name)

    # Imported here rather than at module level (as in the original text).
    from mo.graph.graph import Node

    # NOTE(review): the initialisation of this list is missing from the visible text.
    producer_ports = []

    if self.node.graph.stage == 'front':
        # Front stage: scan the node's inputs for the edge that lands on this port.
        for n, d in self.node.get_inputs():
            if d['in'] == self.idx:
                node = Node(self.node.graph, n)
                producer_ports.append(node.out_port(d['out']))
        # NOTE(review): the lines after the loop are missing from the visible text; returning
        # None when nothing was found follows the "has no source port return None" contract
        # that `disconnected` relies on -- confirm against upstream.
        if not producer_ports:
            return None
    else:
        if self.idx not in self.node.in_nodes():
            return None

        # Later stages: the producer is whatever writes the input data node.
        in_data = self.node.in_node(self.idx)
        for n, d in in_data.get_inputs():
            node = Node(self.node.graph, n)
            producer_ports.append(node.out_port(d['out']))

    if len(producer_ports) != 1:
        raise Error("Something happened with graph! data node has {} producers".format(len(producer_ports)))

    return producer_ports[0]
def get_destination(self):
    """Return the single consumer (destination) port of this output port.

    :return: the consumer port, or None if the port has no consumers
    :raises Error: if the port has more than one consumer
    """
    consumer_ports = self.get_destinations()
    if not consumer_ports:
        return None

    if len(consumer_ports) > 1:
        # NOTE(review): the middle format argument is missing from the visible text; the
        # message has three placeholders and the second one names the port, so the port
        # index is used.
        raise Error("The number of destinations for {} node at {} port is {}".format(self.node.name,
                                                                                     self.idx,
                                                                                     len(consumer_ports)))
    return consumer_ports[0]
def get_destinations(self):
    """Return the list of consumer (destination) ports fed by this output port.

    :return: list of input ports; empty when nothing consumes this port
    """
    assert self.type != 'in', "Can't get destinations for input port at {} node".format(self.node.name)

    # Imported here rather than at module level (as in the original text).
    from mo.graph.graph import Node

    # NOTE(review): the initialisation of this list is missing from the visible text.
    consumer_ports = []
    front_stage = self.node.graph.stage == 'front'
    if front_stage:
        producer_node = self.node
    else:
        # In case if node has no output data node in given port there are no consumers.
        # NOTE(review): the return statement is missing from the visible text; an empty list
        # is used (not None) because `disconnected` calls len() on the result and
        # `disconnect` iterates over it.
        if self.idx not in self.node.out_nodes():
            return []
        producer_node = self.node.out_node(self.idx)

    for n, d in producer_node.get_outputs():
        # In the front stage the outputs of the node itself are listed, i.e. edges of every
        # output port; keep only the edges that leave through this port. `get_source` reads
        # the same 'out' edge attribute as the producer's port index.
        if front_stage and 'out' in d and d['out'] != self.idx:
            continue
        node = Node(self.node.graph, n)
        consumer_ports.append(node.in_port(d['in']))
    return consumer_ports
def disconnect(self):
    """Remove the graph edges that connect this port to other ports.

    For an output port every consumer edge is removed; for an input port the single
    incoming edge whose ``'in'`` attribute equals this port's index is removed.
    """
    graph = self.node.graph
    if self.type == 'out':
        consumer_ports = self.get_destinations()
        if graph.stage == 'front':
            for port in consumer_ports:
                # The graph may hold parallel edges between the same two nodes, and
                # remove_edge(u, v) without a key drops an arbitrary one of them. Select the
                # edge that really goes from this output port to the consumer's input port.
                edges = graph.get_edge_data(self.node.id, port.node.id) or {}
                for key in list(edges):
                    attrs = edges[key]
                    if attrs.get('out', self.idx) == self.idx and attrs.get('in') == port.idx:
                        graph.remove_edge(self.node.id, port.node.id, key=key)
                        break
        else:
            for port in consumer_ports:
                graph.remove_edge(port.node.in_node(port.idx).id, port.node.id)
    else:
        source_port = self.get_source()
        # NOTE(review): the body of this check is missing from the visible text; with no
        # source there is nothing to remove, so the method returns.
        if source_port is None:
            return
        for u, v, d in list(graph.in_edges(self.node.id, data=True)):
            if d['in'] == self.idx:
                for key in graph.get_edge_data(u, v):
                    if graph.get_edge_data(u, v)[key]['in'] == self.idx:
                        graph.remove_edge(u, v, key=key)
                        # NOTE(review): one line is missing here in the visible text;
                        # returning avoids iterating edge data that was just modified.
                        return
def get_connection(self):
    """Build a ``Connection`` object describing the link this port takes part in.

    For an input port the connection is (its source port -> [this port]); for an output
    port it is (this port -> all of its destination ports).
    """
    if self.type == 'in':
        return Connection(self.node.graph, self.get_source(), [self])
    return Connection(self.node.graph, self, self.get_destinations())
def connect(self, port):
    """Connect this port with ``port`` through the port's connection object.

    An input port gets ``port`` set as its source; an output port gets ``port`` added as
    one more destination.

    :param port: the port to connect with
    """
    if self.type == 'in':
        self.get_connection().set_source(port)
    else:
        self.get_connection().add_destination(port)