Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / class_registration.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 import os
19 from enum import Enum
20
21 import networkx as nx
22
23 from mo.graph.graph import Graph
24 from mo.middle.passes.eliminate import graph_clean_up_tf, graph_clean_up_onnx, graph_clean_up
25 from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
26 from mo.utils.error import Error
27 from mo.utils.utils import refer_to_faq_msg
28
29 _registered_classes_dict = {}
30
31
32 def _check_unique_ids():
33     """
34     Check that idxs is unique for all registered replacements.
35     """
36     unique_idxs = set()
37     for class_type, classes_set in _registered_classes_dict.items():
38         for cls in classes_set:
39             replacers = [c for c in cls.registered_cls if not hasattr(c, 'op')] + \
40                         [c for op, c in cls.registered_ops.items() if c]
41             for replacer_cls in replacers:
42                 if hasattr(replacer_cls, 'id'):
43                     id_cls = getattr(replacer_cls, 'id')
44
45                     if id_cls in unique_idxs:
46                         raise Error('Found replacer {} with not unique id!'.format(replacer_cls))
47                     unique_idxs.add(id_cls)
48     log.debug("All replacers has unique idxs.")
49
50
51 def get_enabled_and_disabled_transforms():
52     """
53     :return: tuple of lists with force enabled and disabled id of transformations.
54     """
55     disabled_transforms = os.environ['MO_DISABLED_TRANSFORMS'] if 'MO_DISABLED_TRANSFORMS' in os.environ else ''
56     enabled_transforms = os.environ['MO_ENABLED_TRANSFORMS'] if 'MO_ENABLED_TRANSFORMS' in os.environ else ''
57
58     assert isinstance(enabled_transforms, str)
59     assert isinstance(disabled_transforms, str)
60
61     disabled_transforms = disabled_transforms.split(',')
62     enabled_transforms = enabled_transforms.split(',')
63
64     return enabled_transforms, disabled_transforms
65
66
67 class ClassType(Enum):
68     EXTRACTOR = 0
69     OP = 1
70     FRONT_REPLACER = 2
71     MIDDLE_REPLACER = 3
72     BACK_REPLACER = 4
73
74
75 def _update(cls, registered_list: list, registered_dict: dict, key: str, enabled_transforms: list, disabled_transforms: list):
76     new_keys = {}  # maps a custom name to class
77     new_keys_lower = {}  # translates lowered custom name to its original form
78     # print('Registering new subclasses for', cls)
79
80     for c in cls.__subclasses__():
81         # Force enabling operations
82         if hasattr(c, 'id') and c.id in enabled_transforms:
83             setattr(c, 'enabled', True)
84
85         # Force disabling operations
86         if hasattr(c, 'id') and c.id in disabled_transforms:
87             setattr(c, 'enabled', False)
88
89         if c not in registered_list and (not hasattr(c, 'enabled') or c.enabled):
90             if hasattr(cls, 'excluded_classes') and c in cls.excluded_classes:
91                 continue
92             registered_list.append(c)
93             log.info('New subclass: {}'.format(c))
94             if hasattr(c, key):
95                 k = getattr(c, key)
96                 if k.lower() in new_keys_lower:
97                     raise Error(
98                         'Attempt to register of custom name {} for the second time as class {}. ' \
99                         'Note that custom names are case-insensitive. ' +
100                         refer_to_faq_msg(55), k, c)
101                 else:
102                     new_keys_lower[k.lower()] = k
103                     new_keys[k] = c
104                     log.info('Registered a new subclass with key: {}'.format(k))
105         else:
106             log.warning('Skipped {} registration because it was already registered or it was disabled. '.format(c))
107     registered_dict.update(new_keys)
108
109
110 def update_registration(classes: list, enabled_transforms: list, disabled_transforms: list):
111     for cls in classes:
112         _update(cls, cls.registered_cls, cls.registered_ops, 'op', enabled_transforms, disabled_transforms)
113         _registered_classes_dict.setdefault(cls.class_type(), set()).add(cls)
114
115
116 def apply_replacements(graph: Graph, replacements_type):
117     """
118     Apply all patterns that do not have 'op' first, then apply patterns from registered_ops.
119     If two or more classes replaces the same op (both have op class attribute and values match), such
120     pattern is not applied (while registration it will warn user that we have a conflict).
121     """
122     dependency_graph = Graph()
123     for class_type, classes_set in _registered_classes_dict.items():
124         if class_type == replacements_type:
125             for cls in classes_set:
126                 replacers = [c for c in cls.registered_cls if not hasattr(c, 'op')] + \
127                             [c for op, c in cls.registered_ops.items() if c]
128                 for replacer_cls in replacers:
129                     if replacer_cls in cls.excluded_replacers:
130                         # skip infrastructure classes
131                         continue
132
133                     dependency_graph.add_node(replacer_cls)
134                     for cls_after in replacer_cls().run_before():
135                         log.debug("Replacer {} will be run before {}".format(replacer_cls, cls_after))
136                         dependency_graph.add_edge(replacer_cls, cls_after)
137                     for cls_before in replacer_cls().run_after():
138                         log.debug("Replacer {} will be run after {}".format(replacer_cls, cls_before))
139                         dependency_graph.add_edge(cls_before, replacer_cls)
140
141     try:
142         replacers_order = list(nx.topological_sort(dependency_graph))
143     except nx.NetworkXUnfeasible as exception:
144         cycles = nx.simple_cycles(dependency_graph)
145         raise Error('There is(are) cyclic dependency(ies) between replacers. One of the cycles is the following: {}',
146                     ' -> '.join([str(node) for node in list(cycles)[0]])) from exception
147
148     for replacer_cls in replacers_order:
149         replacer = replacer_cls()
150
151         replacement_id = 'REPLACEMENT_ID'
152         if hasattr(replacer, 'replacement_id'):
153             replacement_id = replacer.replacement_id
154
155         if hasattr(replacer, 'enabled') and not replacer.enabled:
156             log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
157             continue
158
159         if hasattr(replacer, 'graph_condition') and \
160                 not all([condition(graph) for condition in replacer.graph_condition]):
161             log.info("Skip replacer {} (graph_condition not satisfied)".format(replacer_cls))
162             continue
163
164         log.debug("Run replacer {}".format(replacer_cls))
165
166         try:
167             replacer.find_and_replace_pattern(graph)
168
169             if hasattr(replacer, 'force_clean_up') and replacer.force_clean_up:
170                 for_graph_and_each_sub_graph_recursively(
171                     graph,
172                     graph_clean_up_tf if graph.graph['fw'] == 'tf' else
173                     graph_clean_up_onnx if graph.graph['fw'] == 'onnx' else
174                     graph_clean_up)
175
176             for_graph_and_each_sub_graph_recursively(graph, lambda _: graph.check_empty_graph(replacer_cls))
177             for_graph_and_each_sub_graph_recursively(graph, lambda _: graph.check_shapes_consistency())
178
179         except Error as err:
180             raise Error('Exception occurred during running replacer "{}" ({}): {}'.format(
181                 replacement_id,
182                 replacer_cls,
183                 str(err).replace('[REPLACEMENT_ID]', replacement_id),
184             )) from err
185         except Exception as err:
186             raise Exception('Exception occurred during running replacer "{} ({})": {}'.format(
187                 replacement_id,
188                 replacer_cls,
189                 str(err).replace('[REPLACEMENT_ID]', replacement_id),
190             )) from err