2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from collections import namedtuple
24 from mo.front.extractor import add_attrs_props
25 from mo.front.extractor import update_ie_fields
26 from mo.graph.graph import Node, Graph
27 from mo.graph.port import Port
28 from mo.utils import class_registration
29 from mo.utils.error import Error
35 # Add the derived class to excluded_classes if one should not be registered in registered_ops
# Constructor. Reads the IR version from graph.graph['ir_version'], prepares the list of extra
# backend attributes and merges the optional attrs1 / attrs2 dictionaries into self.attrs.
# NOTE(review): this view skips several statements of the method (the construct that selects
# between the two self.ir_version assignments and the initialisation of self.attrs) -- confirm
# against the complete file before relying on this summary.
38 def __init__(self, graph: Graph, attrs1: dict = None, attrs2: dict = None):
# IR version as recorded in the graph.graph mapping.
41 self.ir_version = graph.graph['ir_version']
# Alternative assignment; presumably the fallback when 'ir_version' is unavailable -- TODO confirm.
43 self.ir_version = None
# Extra attribute entries that substitute_ie_attrs appends to the backend attribute list.
49 self.default_backend_attrs = []
# attrs1 is applied first and attrs2 second, so attrs2 wins when both define the same key.
50 if attrs1 is not None:
51 self.attrs.update(attrs1)
52 if attrs2 is not None:
53 self.attrs.update(attrs2)
# Adds a new node to self.graph. The node attributes are self.attrs overlaid with `attrs`, then
# passed through add_attrs_props, update_ie_fields and substitute_ie_attrs before insertion.
# NOTE(review): the creation of new_attrs, the guard around the `attrs` update and the end of the
# method (presumably returning `node`) are not visible in this view -- confirm.
55 def add_node(self, attrs: dict = None):
57 new_attrs.update(self.attrs)
# Caller-supplied attributes are applied after the operation defaults, so they take precedence.
59 new_attrs.update(attrs)
# The requested name (or '' when absent) only seeds the id: the node is renamed to the unique id
# returned by the graph, and that same id is used as the graph key.
60 id_prefix = new_attrs['name'] if 'name' in new_attrs else ''
61 id = self.graph.unique_id(id_prefix)
62 new_attrs['name'] = id
63 new_attrs = add_attrs_props(new_attrs)
64 update_ie_fields(new_attrs, self.ir_version)
65 self.substitute_ie_attrs(new_attrs)
66 self.graph.add_node(id, **new_attrs)
# Wrap the freshly added graph node in a Node object.
68 node = Node(self.graph, id)
# Selects the backend attribute provider that matches self.ir_version and uses it to build the
# 'data' part of an IE layer description.
# Raises Error when self.ir_version is not one of the keys of the mapping below.
# NOTE(review): the closing of the mapping literal and the opening lines of the structure that
# contains the last two visible lines (presumably an update of new_attrs) are missing from this
# view -- confirm against the complete file.
71 def substitute_ie_attrs(self, new_attrs: dict):
73 Replace standard list of attribute in layer/data by attributes
74 delivered by backend_attrs
# IR version -> bound method returning the attribute list for that version.
# Only version 2 uses backend_attrs_v2; None, 3, 4 and 5 all share backend_attrs.
76 backend_attrs_mapping = {
77 None: self.backend_attrs,
78 5: self.backend_attrs,
79 4: self.backend_attrs,
80 3: self.backend_attrs,
81 2: self.backend_attrs_v2
# Fail early on an IR version that has no entry in the mapping.
84 if self.ir_version not in backend_attrs_mapping.keys():
85 raise Error("Unrecognized IR version was specified: {}".format(self.ir_version))
# 'id' is produced by a callable that reads node.node; the remaining entries are plain names.
90 [('id', lambda node: node.node), 'name', 'precision', 'type'],
# The 'data' entry is the version-specific backend attribute list (the provider is called here)
# followed by self.default_backend_attrs.
92 ('data', backend_attrs_mapping[self.ir_version]() + self.default_backend_attrs, []),
# Handles an input given either as a bare node or as a tuple.
# Raises Error when a bare node has outgoing edges carrying more than one distinct 'out' value,
# because the caller then has to name the output port explicitly.
# NOTE(review): the statements that unpack node_port into `node` (and a port) and the return
# statement are not visible in this view -- confirm.
98 def extract_port(node_port):
99 if isinstance(node_port, tuple):
105 # 'data' nodes do not have 'out' edge attibute but always has one output
# 'out' attribute values found on the outgoing edges of the node (edges without 'out' are skipped).
106 out_ids = [attr['out'] for _, __, attr in node.graph.out_edges(node.id, data=True) if 'out' in attr]
# Ambiguity is only an error when the caller did not pass a tuple.
107 if len(set(out_ids)) > 1 and not isinstance(node_port, tuple):
108 raise Error('Node {} has more than one outputs. Provide output port explicitly. '.format(node.name))
# Creates a new node through add_node and connects `node` to it with an edge whose attributes
# are a deep copy of node.out_edge(out_port) updated with edge_attrs.
# NOTE(review): the docstring text mentions removing an edge, but no removal is visible in this
# view; two statements before the add_node call and the end of the method are also not visible.
111 def create_node_on_port(self, node: Node, out_port: int, attrs: dict = None, edge_attrs: dict = None):
113 Removes an edge, that is connected to nodes out_port. Creates new_node with attrs attributes and
114 connects it to node by edge that stores the same information as cutted edge.
115 :param node: Input node, to cut the edge from
116 :param out_port: output port of edge to cut
117 :param attrs: attributes of new node
118 :param edge_attrs: attributes to be changed/added to new edge
119 :return: Node instance of created new_node
# Default edge attributes: 'in' set to 0 (presumably input port 0 of the new node).
121 if edge_attrs is None:
122 edge_attrs = {'in': 0}
# Deep copy, so updating the attributes below does not modify the dictionary of the existing edge.
123 prev_edge_attrs = copy.deepcopy(node.out_edge(out_port))
124 prev_edge_attrs.update(edge_attrs)
125 new_edge_attrs = prev_edge_attrs
128 new_node = self.add_node(attrs)
129 self.graph.add_edge(node.id, new_node.id, **new_edge_attrs)
# Creates a new node through add_node and connects every item of `inputs` to it; the i-th input
# is attached with 'in': i.
# NOTE(review): the handling of inputs=None / attrs=None and the end of the method (presumably
# returning new_node) are not visible in this view -- confirm.
132 def create_node(self, inputs: list = None, attrs: dict = None, edge_attrs: dict = None):
133 # TODO pass also edge attributes to copy to newly created edges
134 # TODO attrs should be matched with attrs()
# Every input is replaced by the result of Op.extract_port.
135 if inputs is not None:
136 inputs = [Op.extract_port(inp) for inp in inputs]
141 new_node = self.add_node(attrs)
142 # Missed careful handling of debug information
143 for i, inp in enumerate(inputs):
# inp[0] is used as the source node and inp[1] as the 'out' value of the edge.
# Sources without a valid 'kind', or with kind == 'op', get the full in/out attribute set;
# any other source gets only 'in' and 'in_attrs'.
144 edge_attr = {'in': i, 'out': inp[1],
145 'in_attrs': ['in', 'permutation'],
146 'out_attrs': ['out', 'permutation'],
147 'data_attrs': []} if not inp[0].has_valid('kind') or inp[0].kind == 'op' \
148 else {'in': i, 'in_attrs': ['in', 'permutation']}
# The same caller-provided edge attributes are applied on top of the defaults for every edge.
149 if edge_attrs is not None:
150 edge_attr.update(edge_attrs)
151 self.graph.add_edge(inp[0].id, new_node.id, **edge_attr)
# Creates a new op node through add_node, connects the given input nodes to it, attaches output
# data node(s) and, when the op node has a valid 'infer' attribute, runs it.
# Returns the single data node when there is exactly one, otherwise the whole collection.
# NOTE(review): several lines are missing from this view (handling of inputs=None / attrs=None,
# part of the input loop body, the tails of some multi-line expressions and the branch keyword
# between the two data_nodes cases) -- confirm details against the complete file.
154 def create_node_with_data(self, inputs: list = None, attrs: dict = None,
155 data_nodes: [Node, np.ndarray, list] = None, edge_attrs: list = None):
157 Creates a new node with given inputs and attrs and also creates data node that
158 holds the op output value. Inputs should be data nodes (not op nodes).
159 Work for ops with a single output port only.
160 Edge attributes in edge_attrs go in order of items in 'inputs'
166 # No need to extract port, because input node should be a data node,
167 # so there is no choice.
168 new_op_node = self.add_node(attrs)
170 # TODO Preserve debug infor
# (source id, destination id, attributes) triples, added to the graph in one call after the loop.
171 inputs_with_edge_attrs = []
172 for i, inp in enumerate(inputs):
# The i-th input is attached with 'in': i; edge_attrs[i], when present, is applied on top of that.
175 edge_attr = {'in': i}
176 if edge_attrs is not None and i < len(edge_attrs):
177 edge_attr.update(edge_attrs[i])
178 inputs_with_edge_attrs.append((inp.id, new_op_node.id, edge_attr))
180 self.graph.add_edges_from(inputs_with_edge_attrs)
182 # TODO: Extend to the case when multiple output ports
# Snapshots of pre-existing output values/shapes; compared with the results of infer further down.
183 old_data_value = [None]
184 old_data_shape = [None]
# No output data node supplied: create a fresh one with value=None and shape=None.
185 if data_nodes is None:
186 data_node = self.graph.unique_id()
187 self.graph.add_node(data_node, **add_attrs_props(
188 dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
190 data_nodes = [Node(self.graph, data_node)]
# Anything that is not a list or ndarray is treated as a single node and wrapped into a list.
192 if type(data_nodes) not in [list, np.ndarray]:
193 data_nodes = [data_nodes]
# Copies are stored so the snapshots are independent of the objects held by the data nodes.
194 old_data_value = [data_node.value.copy() if data_node.has_valid('value') else None for data_node in
196 old_data_shape = [data_node.shape.copy() if data_node.has_valid('shape') else None for data_node in
# The k-th data node is connected to the op node with 'out': k.
198 for id, data_node in enumerate(data_nodes):
199 self.graph.add_edges_from([(new_op_node.id, data_node.id, {'out': id})])
201 if new_op_node.has_valid('infer'):
202 if log.getLogger().isEnabledFor(log.DEBUG):
203 log.debug('Start running infer function for individual op node with attributes: {}'
204 ''.format(str(new_op_node)))
# The infer callable receives the op node itself.
205 new_op_node.infer(new_op_node)
# Consistency checks: either no previous value/shape was recorded, or the data nodes still hold
# values/shapes equal to the recorded ones after infer.
206 assert all(old_value is None for old_value in old_data_value) or all(
207 [np.array_equal(old_data_value[id], data_node.value) for id, data_node in enumerate(data_nodes)])
208 assert all(old_shape is None for old_shape in old_data_shape) or all(
209 [np.array_equal(old_data_shape[id], data_node.shape) for id, data_node in enumerate(data_nodes)]), \
210 "After re-inference of {} node, old and new shapes do not match. Old shapes: {}, new shapes: {}.".format(
211 new_op_node.soft_get('name'),
212 [old_data_shape[id] for id in range(len(data_nodes))],
213 [data_node.shape for data_node in data_nodes])
214 for data_node in data_nodes:
215 if log.getLogger().isEnabledFor(log.DEBUG):
217 'Finished running infer function, data nodes attributes: {}'.format(data_node))
218 return data_nodes[0] if len(data_nodes) == 1 else data_nodes
# Creates a 'data' node for the output of op_node, adds it to `graph` and connects
# op_node -> data node with an edge carrying 'out': out_port (plus edge_attrs when given).
# Asserts that op_node is an 'op' node that has no output nodes yet.
# NOTE(review): the attrs=None handling, the tail of the dict(...) call, the branch keyword
# before the second add_edges_from call and the return are not visible in this view -- confirm.
221 def create_data_node(graph: Graph, op_node: Node, attrs: dict = None, edge_attrs: dict = None, out_port=0):
222 assert op_node is not None and op_node.kind == 'op'
223 assert len(op_node.out_nodes()) == 0
# The op node id seeds the unique id of the data node; that id is also used as its name.
227 data_node = graph.unique_id(op_node.id)
228 defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
# Caller-supplied attributes override the defaults above.
230 defaul_attrs.update(attrs)
231 graph.add_node(data_node, **add_attrs_props(defaul_attrs))
# From here on data_node is a Node wrapper instead of the plain id.
232 data_node = Node(graph, data_node)
# Keys in edge_attrs come after 'out' in the literal, so an 'out' key in edge_attrs would win.
233 if edge_attrs is not None:
234 graph.add_edges_from([(op_node.id, data_node.id, {'out': out_port, **edge_attrs})])
236 graph.add_edges_from([(op_node.id, data_node.id, {'out': out_port})])
# Creates an unconnected 'data' node in `graph`; `name` seeds the unique id, which is also
# stored as the node name. Caller-supplied attrs override the defaults.
# NOTE(review): the attrs=None handling, the tail of the dict(...) call and the return are not
# visible in this view -- confirm.
240 def _create_data_node(graph: Graph, name: str, attrs: dict = None):
244 data_node = graph.unique_id(name)
245 defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
247 defaul_attrs.update(attrs)
248 graph.add_node(data_node, **add_attrs_props(defaul_attrs))
# From here on data_node is a Node wrapper instead of the plain id.
249 data_node = Node(graph, data_node)
def create_input_data_node(graph: Graph, name: str, value: np.array, attrs: dict = None):
    """
    Create a 'data' node that holds ``value`` and add it to ``graph``.

    :param graph: graph the node is added to
    :param name: prefix handed to graph.unique_id(); the resulting id is also stored as the node name
    :param value: value to store; it is converted with np.array() and the shape of the converted
                  array is stored as 'shape', so array-likes without a ``shape`` attribute work too
    :param attrs: optional attributes overriding the defaults; the passed object is never modified
    :return: Node instance wrapping the created data node
    """
    # Convert once so that 'value' and 'shape' are guaranteed to describe the same array.
    value = np.array(value)
    node_id = graph.unique_id(name)
    data_attrs = dict(kind='data', precision="FP32", name=node_id, value=value,
                      shape=np.array(value.shape),
                      data_type=None, infer=None)
    # None replaces the former shared mutable default ({}); an empty/None attrs changes nothing.
    if attrs:
        data_attrs.update(attrs)
    graph.add_node(node_id, **add_attrs_props(data_attrs))
    return Node(graph, node_id)
# Creates a 'data' node in `graph` and connects it as an input of op_node: the edge goes
# data node -> op_node and carries edge_attrs. Asserts that op_node is an 'op' node.
# NOTE(review): the attrs=None handling, the body of the edge_attrs guard, the tail of the
# dict(...) call and the return are not visible in this view -- confirm.
263 def create_and_connect_input_data_node(graph: Graph, op_node: Node, attrs: dict = None, edge_attrs: dict = None):
264 assert op_node is not None and op_node.kind == 'op'
267 if edge_attrs is None:
# The op node id seeds the unique id of the data node; that id is also used as its name.
270 data_node = graph.unique_id(op_node.id)
271 defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
# Caller-supplied attributes override the defaults above.
273 defaul_attrs.update(attrs)
274 graph.add_node(data_node, **add_attrs_props(defaul_attrs))
# From here on data_node is a Node wrapper instead of the plain id.
275 data_node = Node(graph, data_node)
276 graph.add_edges_from([(data_node.id, op_node.id, edge_attrs)])
# Builds the attribute set the same way add_node does (self.attrs overlaid with `attrs`, then
# add_attrs_props, update_ie_fields and substitute_ie_attrs) and iterates over the result.
# NOTE(review): the creation of new_attrs, the guard around the `attrs` update and the loop
# body (presumably storing each pair on `node`) are not visible in this view -- confirm.
279 def update_node(self, node: Node, attrs: dict = None):
281 Updates/creates new attributes in node based on self.attrs and attrs.
284 new_attrs.update(self.attrs)
# Caller-supplied attributes are applied after the operation defaults, so they take precedence.
286 new_attrs.update(attrs)
287 new_attrs = add_attrs_props(new_attrs)
288 update_ie_fields(new_attrs, self.ir_version)
289 self.substitute_ie_attrs(new_attrs)
290 for k, v in new_attrs.items():
# Class-level helper: instantiates the operation class `cls` on node.graph with `attrs`.
# NOTE(review): the statements before and after the instantiation (including what is done with
# `op`) are not visible in this view -- confirm against the complete file.
294 def update_node_stat(cls, node: Node, attrs: dict = None):
# `attrs` is passed as the first attribute dictionary of the constructor.
297 op = cls(node.graph, attrs)
# Reports the attributes a user should/can set for the operation; backend_attrs() returns the
# result of this method.
# NOTE(review): the return statement is not visible in this view -- confirm the default value.
300 def supported_attrs(self):
302 Attributes that user should/can set for the operation
def backend_attrs(self):
    """
    Attributes that will be translated to back-end IR.

    The default implementation simply forwards whatever supported_attrs() reports.
    """
    attrs_for_backend = self.supported_attrs()
    return attrs_for_backend
def backend_attrs_v2(self):
    """
    Attributes that will be translated to back-end IR when the IR version is 2.

    The default implementation forwards the result of backend_attrs().
    """
    attrs_for_v2 = self.backend_attrs()
    return attrs_for_v2
def get_op_class_by_name(name: str):
    """
    Look up an entry of the registered_ops registry by its name.

    :param name: key under which the entry is stored in registered_ops
    :return: the object stored in registered_ops for ``name``; a missing key propagates
             whatever error the registry raises on indexing
    """
    # Inside a class body, __class__ refers to the class that defines this function.
    registry = __class__.registered_ops
    return registry[name]
321 return class_registration.ClassType.OP
# Appends `dims_to_add` trailing axes to node.value (one np.expand_dims(axis=-1) per iteration)
# and then refreshes node.shape from the resulting value.
# NOTE(review): the body of the guard below is not visible in this view; presumably nodes that
# are None or have no valid 'value' are left untouched -- confirm.
324 def expand_node_shape(node: Node, dims_to_add):
325 if node is None or not node.has_valid('value'):
327 for idx in range(dims_to_add):
328 node.value = np.expand_dims(node.value, axis=-1)
# Keep the stored shape consistent with the expanded value.
329 node.shape = np.array(node.value.shape)
# A permutation ('perm') together with its inverse ('inv').
Permutation = namedtuple('Permutation', ['perm', 'inv'])
# One permutable attribute: its name, the port string it is bound to and the permuting function.
Attr = namedtuple('Attr', ['name', 'port', 'func'])

def common_permutation(node, permutation, attr):
    """Return node[attr] indexed with permutation.perm."""
    attr_value = node[attr]
    return attr_value[permutation.perm]

def common_permutation_inv(node, permutation, attr):
    """Return permutation.inv indexed with the value stored in node[attr]."""
    attr_value = node[attr]
    return permutation.inv[attr_value]
# Default permutation function for each known attribute name; __attr looks entries up here by
# attribute name.
# NOTE(review): the closing brace of this literal is not visible in this view.
339 # List of default permutations
340 common_attrs_permutation = {
# Entries using common_permutation: node[attr] is indexed with permutation.perm.
341 'dim': common_permutation,
342 'pad': common_permutation,
343 'pads': common_permutation,
344 'shape': common_permutation,
# 'order' combines both steps: node[attr] indexed with permutation.perm, and the result is then
# used to index permutation.inv.
345 'order': lambda node, permutation, attr: permutation.inv[node[attr][permutation.perm]],
346 'stride': common_permutation,
347 'window': common_permutation,
348 'dilation': common_permutation,
349 'kernel_shape': common_permutation,
350 'output_shape': common_permutation,
351 'slices': common_permutation,
352 'shrink_axis_mask': common_permutation,
353 'new_axis_mask': common_permutation,
# Entries using common_permutation_inv: the value of node[attr] is used to index permutation.inv.
355 'axis': common_permutation_inv,
356 'batch_dims': common_permutation_inv,
357 'channel_dims': common_permutation_inv,
358 'spatial_dims': common_permutation_inv,
360 'input_channel_dim': common_permutation_inv,
361 'output_channel_dim': common_permutation_inv,
362 'kernel_spatial_idx': common_permutation_inv,
363 'input_feature_channel': common_permutation_inv,
364 'output_feature_channel': common_permutation_inv,
# Builds a PermuteAttrs.Attr record for attribute `name` bound to `port`.
# Raises Error when no permutation function is known for `name`, or when `port` is not of the
# form 'input:<x>' / 'output:<x>'.
# NOTE(review): the conditions that guard the default lookup and the raise (presumably "only
# when func is not given" and an else branch) are not visible in this view -- confirm.
368 def __attr(name, port, func=None):
# Default permutation function registered for this attribute name.
370 if name in PermuteAttrs.common_attrs_permutation:
371 func = PermuteAttrs.common_attrs_permutation[name]
373 raise Error('Attr {} is missing in PermuteAttrs.common_attrs_permutation. Please update '
374 'common_attrs_permutation with permutation for your attribute!'.format(name))
# port must split on ':' into exactly two parts, the first being 'input' or 'output'.
376 if len(port.split(':')) != 2 or port.split(':')[0] not in ['input', 'output']:
377 raise Error("Attribute port {} for {} wasn't set correctly!".format(port, name))
379 return PermuteAttrs.Attr(name=name, port=port, func=func)
# Registers permutable attributes. Each entry must be a tuple (attribute_name, port) or
# (attribute_name, port, func); it is converted with __attr and stored in self.attrs under the
# attribute name, replacing any earlier entry for that name. Raises Error for malformed entries.
# NOTE(review): the loop header that binds `attr` (presumably iterating over `attrs`) and the
# end of the method are not visible in this view -- confirm.
384 def update_attrs(self, attrs):
386 if not isinstance(attr, tuple) or len(attr) not in [2, 3]:
387 raise Error('attr object must be a tuple: (attribute_name, port) or (attribute_name, port, func)')
# The tuple is unpacked positionally into __attr(name, port[, func]).
388 self.attrs.update({attr[0]: self.__attr(*attr)})
# Applies the registered permutations to `node`: for every registered attribute the neighbouring
# node named by the port string is inspected and, if it has a valid 'permutation', node[name] is
# overwritten with the result of the attribute's permutation function.
# NOTE(review): a statement between the port split and its use (presumably converting the port
# to int) and the branch keyword before the last assignment are not visible -- confirm.
391 def permute_attrs(self, node):
392 # This function applies permutation for given node
393 for attr in self.attrs.keys():
394 name, port, func = self.attrs[attr]
# Port strings have the form '<input|output>:<port>'.
395 node_type, port = port.split(':')
# 'input' ports are looked up among the input nodes, anything else among the output nodes.
397 node_with_permutation = node.in_node(port) if node_type == 'input' else node.out_node(port)
399 if node_with_permutation.has_valid('permutation'):
400 permutation = node_with_permutation.permutation
# A permutation stored as a function is first called with the node to obtain the permutation.
401 if isinstance(permutation, type(lambda: 0)):
402 node[name] = func(node, permutation(node), name)
404 node[name] = func(node, permutation, name)
def create_permute_attrs(node, attrs=None):
    """
    Make sure ``node`` carries a PermuteAttrs object and register ``attrs`` in it.

    :param node: node whose 'permute_attrs' attribute is created (when not valid yet) and updated
    :param attrs: attribute descriptions handed unchanged to PermuteAttrs.update_attrs()
    """
    already_present = node.has_valid('permute_attrs')
    if not already_present:
        # First use on this node: start from an empty PermuteAttrs.
        node['permute_attrs'] = PermuteAttrs()
    permute_attrs = node['permute_attrs']
    permute_attrs.update_attrs(attrs)
# Stores `permutation` on the edge node1 -> node2 (edge key 0) when the edge has no
# 'permutation' attribute yet. Raises Error when the edge already carries a permutation that
# differs from the given one.
# NOTE(review): the last argument of the nx.set_edge_attributes call and the branch keyword
# before the consistency check are not visible in this view -- confirm.
414 def set_permutation(node1, node2, permutation):
415 # This function creates permutation on edge between node1->node2
# Attribute dictionary of the edge with key 0 between the two nodes.
416 edge_attrs = node1.graph.get_edge_data(node1.id, node2.id)[0]
417 if 'permutation' not in edge_attrs:
418 nx.set_edge_attributes(G=node1.graph,
419 values={(node1.id, node2.id, 0): permutation},
422 # If permutation exists we check that given and already set permutations are equal
# A stored None only matches a given None; otherwise the two are compared with np.array_equal.
423 if (edge_attrs['permutation'] is None and permutation is not None) or \
424 not np.array_equal(edge_attrs['permutation'], permutation):
425 raise Error('Permutation already exists in edge between {} and {}'.format(node1.name, node2.name))
# Computes the inverse of permutation `perm`; the result list has len(perm) entries and starts
# out filled with zeros.
# NOTE(review): the loop body and the return statement are not visible in this view -- confirm.
428 def get_inverse_permutation(perm):
429 inv = [0] * len(perm)
430 # Create reverse permutation
# index is the position inside perm, pos is the value stored at that position.
431 for index, pos in enumerate(perm):
# Returns a PermuteAttrs.Permutation (perm and inv as numpy arrays) for converting NHWC to NCHW
# with `dims_number` dimensions.
# NOTE(review): the condition that selects between the two `perm` assignments (per the comment
# below, the identity is meant for 3D shapes) and the tail of the first assignment are not
# visible in this view -- confirm.
436 def get_nhwc_to_nchw_permutation(dims_number: int):
437 # This function returns permutation from NHWC to NCHW for given dims number
# For more than one dimension: keep axis 0, move the last axis to position 1, then the rest in order.
439 perm = [0, dims_number - 1, *[x for x in range(1, dims_number - 1)]] if dims_number > 1 else [x for x in range(
442 # Exclude 3D shapes from permutation process: identity permutation
443 perm = list(range(0, dims_number))
444 inv = PermuteAttrs.get_inverse_permutation(perm)
445 return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv))
# Returns a PermuteAttrs.Permutation (perm and inv as numpy arrays) for converting NCHW to NHWC
# with `dims_number` dimensions.
# NOTE(review): the condition that selects between the two `perm` assignments (per the comment
# below, the identity is meant for 3D shapes) and the tail of the first assignment are not
# visible in this view -- confirm.
448 def get_nchw_to_nhwc_permutation(dims_number: int):
449 # This function returns permutation from NCHW to NHWC for given dims number
# For more than one dimension: keep axis 0, then axes 2..dims_number-1 in order, then axis 1 last.
451 perm = [0, *[x for x in range(2, dims_number)], 1] if dims_number > 1 else [x for x in range(
454 # Exclude 3D shapes from permutation process: identity permutation
455 perm = list(range(0, dims_number))
456 inv = PermuteAttrs.get_inverse_permutation(perm)
457 return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv))