Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / op.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import copy
18 import logging as log
19 from collections import namedtuple
20
21 import networkx as nx
22 import numpy as np
23
24 from mo.front.extractor import add_attrs_props
25 from mo.front.extractor import update_ie_fields
26 from mo.graph.graph import Node, Graph
27 from mo.graph.port import Port
28 from mo.utils import class_registration
29 from mo.utils.error import Error
30
31
32 class Op(object):
33     registered_ops = {}
34     registered_cls = []
35     # Add the derived class to excluded_classes if one should not be registered in registered_ops
36     excluded_classes = []
37
38     def __init__(self, graph: Graph, attrs1: dict = None, attrs2: dict = None):
39         self.graph = graph
40         try:
41             self.ir_version = graph.graph['ir_version']
42         except:
43             self.ir_version = None
44
45         self.attrs = {
46             'precision': "FP32",
47             'kind': 'op'
48         }
49         self.default_backend_attrs = []
50         if attrs1 is not None:
51             self.attrs.update(attrs1)
52         if attrs2 is not None:
53             self.attrs.update(attrs2)
54
55     def add_node(self, attrs: dict = None):
56         new_attrs = {}
57         new_attrs.update(self.attrs)
58         if attrs is not None:
59             new_attrs.update(attrs)
60         id_prefix = new_attrs['name'] if 'name' in new_attrs else ''
61         id = self.graph.unique_id(id_prefix)
62         new_attrs['name'] = id
63         new_attrs = add_attrs_props(new_attrs)
64         update_ie_fields(new_attrs, self.ir_version)
65         self.substitute_ie_attrs(new_attrs)
66         self.graph.add_node(id, **new_attrs)
67
68         node = Node(self.graph, id)
69         return node
70
71     def substitute_ie_attrs(self, new_attrs: dict):
72         """
73         Replace standard list of attribute in layer/data by attributes
74         delivered by backend_attrs
75         """
76         backend_attrs_mapping = {
77             None: self.backend_attrs,
78             5: self.backend_attrs,
79             4: self.backend_attrs,
80             3: self.backend_attrs,
81             2: self.backend_attrs_v2
82         }
83
84         if self.ir_version not in backend_attrs_mapping.keys():
85             raise Error("Unrecognized IR version was specified: {}".format(self.ir_version))
86
87         new_attrs.update({
88             'IE': [(
89                 'layer',
90                 [('id', lambda node: node.node), 'name', 'precision', 'type'],
91                 [
92                     ('data', backend_attrs_mapping[self.ir_version]() + self.default_backend_attrs, []),
93                     '@ports',
94                     '@consts'])]
95         })
96
97     @staticmethod
98     def extract_port(node_port):
99         if isinstance(node_port, tuple):
100             node = node_port[0]
101             port = node_port[1]
102         else:
103             node = node_port
104             port = 0
105         # 'data' nodes do not have 'out' edge attibute but always has one output
106         out_ids = [attr['out'] for _, __, attr in node.graph.out_edges(node.id, data=True) if 'out' in attr]
107         if len(set(out_ids)) > 1 and not isinstance(node_port, tuple):
108             raise Error('Node {} has more than one outputs. Provide output port explicitly. '.format(node.name))
109         return node, port
110
111     def create_node_on_port(self, node: Node, out_port: int, attrs: dict = None, edge_attrs: dict = None):
112         """
113         Removes an edge, that is connected to nodes out_port. Creates new_node with attrs attributes and
114         connects it to node by edge that stores the same information as cutted edge.
115         :param node: Input node, to cut the edge from
116         :param out_port: output port of edge to cut
117         :param attrs: attributes of new node
118         :param edge_attrs: attributes to be changed/added to new edge
119         :return: Node instance of created new_node
120         """
121         if edge_attrs is None:
122             edge_attrs = {'in': 0}
123         prev_edge_attrs = copy.deepcopy(node.out_edge(out_port))
124         prev_edge_attrs.update(edge_attrs)
125         new_edge_attrs = prev_edge_attrs
126         if attrs is None:
127             attrs = dict()
128         new_node = self.add_node(attrs)
129         self.graph.add_edge(node.id, new_node.id, **new_edge_attrs)
130         return new_node
131
132     def create_node(self, inputs: list = None, attrs: dict = None, edge_attrs: dict = None):
133         # TODO pass also edge attributes to copy to newly created edges
134         # TODO attrs should be matched with attrs()
135         if inputs is not None:
136             inputs = [Op.extract_port(inp) for inp in inputs]
137         else:
138             inputs = []
139         if attrs is None:
140             attrs = dict()
141         new_node = self.add_node(attrs)
142         # Missed careful handling of debug information
143         for i, inp in enumerate(inputs):
144             edge_attr = {'in': i, 'out': inp[1],
145                          'in_attrs': ['in', 'permutation'],
146                          'out_attrs': ['out', 'permutation'],
147                          'data_attrs': []} if not inp[0].has_valid('kind') or inp[0].kind == 'op' \
148                 else {'in': i, 'in_attrs': ['in', 'permutation']}
149             if edge_attrs is not None:
150                 edge_attr.update(edge_attrs)
151             self.graph.add_edge(inp[0].id, new_node.id, **edge_attr)
152         return new_node
153
154     def create_node_with_data(self, inputs: list = None, attrs: dict = None,
155                               data_nodes: [Node, np.ndarray, list] = None, edge_attrs: list = None):
156         """
157         Creates a new node with given inputs and attrs and also creates data node that
158         holds the op output value. Inputs should be data nodes (not op nodes).
159         Work for ops with a single output port only.
160         Edge attributes in edge_attrs go in order of items in 'inputs'
161         """
162         if inputs is None:
163             inputs = []
164         if attrs is None:
165             attrs = {}
166         # No need to extract port, because input node should be a data node,
167         # so there is no choice.
168         new_op_node = self.add_node(attrs)
169
170         # TODO Preserve debug infor
171         inputs_with_edge_attrs = []
172         for i, inp in enumerate(inputs):
173             if inp is None:
174                 continue
175             edge_attr = {'in': i}
176             if edge_attrs is not None and i < len(edge_attrs):
177                 edge_attr.update(edge_attrs[i])
178             inputs_with_edge_attrs.append((inp.id, new_op_node.id, edge_attr))
179         
180         self.graph.add_edges_from(inputs_with_edge_attrs)
181         
182         # TODO: Extend to the case when multiple output ports
183         old_data_value = [None]
184         old_data_shape = [None]
185         if data_nodes is None:
186             data_node = self.graph.unique_id()
187             self.graph.add_node(data_node, **add_attrs_props(
188                 dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
189                      infer=None)))
190             data_nodes = [Node(self.graph, data_node)]
191         else:
192             if type(data_nodes) not in [list, np.ndarray]:
193                 data_nodes = [data_nodes]
194             old_data_value = [data_node.value.copy() if data_node.has_valid('value') else None for data_node in
195                               data_nodes]
196             old_data_shape = [data_node.shape.copy() if data_node.has_valid('shape') else None for data_node in
197                               data_nodes]
198         for id, data_node in enumerate(data_nodes):
199             self.graph.add_edges_from([(new_op_node.id, data_node.id, {'out': id})])
200
201         if new_op_node.has_valid('infer'):
202             if log.getLogger().isEnabledFor(log.DEBUG):
203                 log.debug('Start running infer function for individual op node with attributes: {}'
204                           ''.format(str(new_op_node)))
205             new_op_node.infer(new_op_node)
206             assert all(old_value is None for old_value in old_data_value) or all(
207                 [np.array_equal(old_data_value[id], data_node.value) for id, data_node in enumerate(data_nodes)])
208             assert all(old_shape is None for old_shape in old_data_shape) or all(
209                 [np.array_equal(old_data_shape[id], data_node.shape) for id, data_node in enumerate(data_nodes)]), \
210                 "After re-inference of {} node, old and new shapes do not match. Old shapes: {}, new shapes: {}.".format(
211                     new_op_node.soft_get('name'),
212                     [old_data_shape[id] for id in range(len(data_nodes))],
213                     [data_node.shape for data_node in data_nodes])
214             for data_node in data_nodes:
215                 if log.getLogger().isEnabledFor(log.DEBUG):
216                     log.debug(
217                         'Finished running infer function, data nodes attributes: {}'.format(data_node))
218         return data_nodes[0] if len(data_nodes) == 1 else data_nodes
219
220     @staticmethod
221     def create_data_node(graph: Graph, op_node: Node, attrs: dict = None, edge_attrs: dict = None, out_port=0):
222         assert op_node is not None and op_node.kind == 'op'
223         assert len(op_node.out_nodes()) == 0
224         if attrs is None:
225             attrs = {}
226
227         data_node = graph.unique_id(op_node.id)
228         defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
229                             infer=None)
230         defaul_attrs.update(attrs)
231         graph.add_node(data_node, **add_attrs_props(defaul_attrs))
232         data_node = Node(graph, data_node)
233         if edge_attrs is not None:
234             graph.add_edges_from([(op_node.id, data_node.id, {'out': out_port, **edge_attrs})])
235         else:
236             graph.add_edges_from([(op_node.id, data_node.id, {'out': out_port})])
237         return data_node
238
239     @staticmethod
240     def _create_data_node(graph: Graph, name: str, attrs: dict = None):
241         if attrs is None:
242             attrs = {}
243
244         data_node = graph.unique_id(name)
245         defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
246                             infer=None)
247         defaul_attrs.update(attrs)
248         graph.add_node(data_node, **add_attrs_props(defaul_attrs))
249         data_node = Node(graph, data_node)
250         return data_node
251
252     @staticmethod
253     def create_input_data_node(graph: Graph, name: str, value: np.array, attrs: dict = {}):
254         data_node = graph.unique_id(name)
255         defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=np.array(value),
256                             shape=np.array(value.shape),
257                             data_type=None, infer=None)
258         defaul_attrs.update(attrs)
259         graph.add_node(data_node, **add_attrs_props(defaul_attrs))
260         return Node(graph, data_node)
261
262     @staticmethod
263     def create_and_connect_input_data_node(graph: Graph, op_node: Node, attrs: dict = None, edge_attrs: dict = None):
264         assert op_node is not None and op_node.kind == 'op'
265         if attrs is None:
266             attrs = {}
267         if edge_attrs is None:
268             edge_attrs = {}
269
270         data_node = graph.unique_id(op_node.id)
271         defaul_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
272                             infer=None)
273         defaul_attrs.update(attrs)
274         graph.add_node(data_node, **add_attrs_props(defaul_attrs))
275         data_node = Node(graph, data_node)
276         graph.add_edges_from([(data_node.id, op_node.id, edge_attrs)])
277         return data_node
278
279     def update_node(self, node: Node, attrs: dict = None):
280         """
281         Updates/creates new attributes in node based on self.attrs and attrs.
282         """
283         new_attrs = {}
284         new_attrs.update(self.attrs)
285         if attrs:
286             new_attrs.update(attrs)
287         new_attrs = add_attrs_props(new_attrs)
288         update_ie_fields(new_attrs, self.ir_version)
289         self.substitute_ie_attrs(new_attrs)
290         for k, v in new_attrs.items():
291             node[k] = v
292
293     @classmethod
294     def update_node_stat(cls, node: Node, attrs: dict = None):
295         if attrs is None:
296             attrs = dict()
297         op = cls(node.graph, attrs)
298         op.update_node(node)
299
300     def supported_attrs(self):
301         """
302         Attributes that user should/can set for the operation
303         """
304         return []
305
306     def backend_attrs(self):
307         """
308         Attributes that will be translated to back-end IR
309         """
310         return self.supported_attrs()
311
312     def backend_attrs_v2(self):
313         return self.backend_attrs()
314
315     @staticmethod
316     def get_op_class_by_name(name: str):
317         return __class__.registered_ops[name]
318
319     @classmethod
320     def class_type(cls):
321         return class_registration.ClassType.OP
322
323     @staticmethod
324     def expand_node_shape(node: Node, dims_to_add):
325         if node is None or not node.has_valid('value'):
326             return
327         for idx in range(dims_to_add):
328             node.value = np.expand_dims(node.value, axis=-1)
329         node.shape = np.array(node.value.shape)
330
331
332 class PermuteAttrs:
333     Permutation = namedtuple('Permutation', ['perm', 'inv'])
334     Attr = namedtuple('Attr', ['name', 'port', 'func'])
335
336     common_permutation = lambda node, permutation, attr: node[attr][permutation.perm]
337     common_permutation_inv = lambda node, permutation, attr: permutation.inv[node[attr]]
338
339     # List of default permutations
340     common_attrs_permutation = {
341             'dim': common_permutation,
342             'pad': common_permutation,
343             'pads': common_permutation,
344             'shape': common_permutation,
345             'order': lambda node, permutation, attr: permutation.inv[node[attr][permutation.perm]],
346             'stride': common_permutation,
347             'window': common_permutation,
348             'dilation': common_permutation,
349             'kernel_shape': common_permutation,
350             'output_shape': common_permutation,
351             'slices': common_permutation,
352             'shrink_axis_mask': common_permutation,
353             'new_axis_mask': common_permutation,
354
355             'axis': common_permutation_inv,
356             'batch_dims': common_permutation_inv,
357             'channel_dims': common_permutation_inv,
358             'spatial_dims': common_permutation_inv,
359
360             'input_channel_dim': common_permutation_inv,
361             'output_channel_dim': common_permutation_inv,
362             'kernel_spatial_idx': common_permutation_inv,
363             'input_feature_channel': common_permutation_inv,
364             'output_feature_channel': common_permutation_inv,
365     }
366
367     @staticmethod
368     def __attr(name, port, func=None):
369         if func is None:
370             if name in PermuteAttrs.common_attrs_permutation:
371                 func = PermuteAttrs.common_attrs_permutation[name]
372             else:
373                 raise Error('Attr {} is missing in PermuteAttrs.common_attrs_permutation. Please update '
374                             'common_attrs_permutation with permutation for your attribute!'.format(name))
375
376         if len(port.split(':')) != 2 or port.split(':')[0] not in ['input', 'output']:
377             raise Error("Attribute port {} for {} wasn't set correctly!".format(port, name))
378
379         return PermuteAttrs.Attr(name=name, port=port, func=func)
380
381     def __init__(self):
382         self.attrs = {}
383
384     def update_attrs(self, attrs):
385         for attr in attrs:
386             if not isinstance(attr, tuple) or len(attr) not in [2, 3]:
387                 raise Error('attr object must be a tuple: (attribute_name, port) or (attribute_name, port, func)')
388             self.attrs.update({attr[0]: self.__attr(*attr)})
389         return self
390
391     def permute_attrs(self, node):
392         # This function applies permutation for given node
393         for attr in self.attrs.keys():
394             name, port, func = self.attrs[attr]
395             node_type, port = port.split(':')
396             port = int(port)
397             node_with_permutation = node.in_node(port) if node_type == 'input' else node.out_node(port)
398
399             if node_with_permutation.has_valid('permutation'):
400                 permutation = node_with_permutation.permutation
401                 if isinstance(permutation, type(lambda: 0)):
402                     node[name] = func(node, permutation(node), name)
403                 else:
404                     node[name] = func(node, permutation, name)
405
406     @staticmethod
407     def create_permute_attrs(node, attrs=None):
408         # Create permute_attrs if not exists
409         if not node.has_valid('permute_attrs'):
410             node['permute_attrs'] = PermuteAttrs()
411         node['permute_attrs'].update_attrs(attrs)
412
413     @staticmethod
414     def set_permutation(node1, node2, permutation):
415         # This function creates permutation on edge between node1->node2
416         edge_attrs = node1.graph.get_edge_data(node1.id, node2.id)[0]
417         if 'permutation' not in edge_attrs:
418             nx.set_edge_attributes(G=node1.graph,
419                                    values={(node1.id, node2.id, 0): permutation},
420                                    name='permutation')
421         else:
422             # If permutation exists we check that given and already set permutations are equal
423             if (edge_attrs['permutation'] is None and permutation is not None) or \
424                 not np.array_equal(edge_attrs['permutation'], permutation):
425                 raise Error('Permutation already exists in edge between {} and {}'.format(node1.name, node2.name))
426
427     @staticmethod
428     def get_inverse_permutation(perm):
429         inv = [0] * len(perm)
430         # Create reverse permutation
431         for index, pos in enumerate(perm):
432             inv[pos] = index
433         return inv
434
435     @staticmethod
436     def get_nhwc_to_nchw_permutation(dims_number: int):
437         # This function returns permutation from NHWC to NCHW for given dims number
438         if dims_number != 3:
439             perm = [0, dims_number - 1, *[x for x in range(1, dims_number - 1)]] if dims_number > 1 else [x for x in range(
440                 dims_number)]
441         else:
442             # Exclude 3D shapes from permutation process: identity permutation
443             perm = list(range(0, dims_number))
444         inv = PermuteAttrs.get_inverse_permutation(perm)
445         return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv))
446
447     @staticmethod
448     def get_nchw_to_nhwc_permutation(dims_number: int):
449         # This function returns permutation from NCHW to NHWC for given dims number
450         if dims_number != 3:
451             perm = [0, *[x for x in range(2, dims_number)], 1] if dims_number > 1 else [x for x in range(
452                 dims_number)]
453         else:
454             # Exclude 3D shapes from permutation process: identity permutation
455             perm = list(range(0, dims_number))
456         inv = PermuteAttrs.get_inverse_permutation(perm)
457         return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv))