"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log
import os
from enum import Enum

from mo.graph.graph import Graph
from mo.middle.passes.eliminate import graph_clean_up_tf, graph_clean_up_onnx, graph_clean_up
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
# Module-level registry filled by update_registration(): keyed by `cls.class_type()`, each value is
# the set of registry base classes passed to update_registration() with that class type.
_registered_classes_dict = {}
def _check_unique_ids():
    """
    Check that every registered replacer declaring an ``id`` uses an id no other replacer uses.

    Replacers are collected from every registry class stored in ``_registered_classes_dict``:
    the classes in ``registered_cls`` that have no ``op`` attribute, plus every truthy value of
    ``registered_ops``.

    :raises Error: if two different replacer classes share the same ``id``.
    """
    owner_by_id = {}  # id -> first replacer class seen with that id
    for classes_set in _registered_classes_dict.values():
        for cls in classes_set:
            replacers = [c for c in cls.registered_cls if not hasattr(c, 'op')] + \
                        [c for c in cls.registered_ops.values() if c]
            for replacer_cls in replacers:
                if not hasattr(replacer_cls, 'id'):
                    continue
                owner = owner_by_id.setdefault(replacer_cls.id, replacer_cls)
                # A class reachable more than once (e.g. from two registries) is not a conflict;
                # only a *different* class reusing the id is.
                if owner is not replacer_cls:
                    raise Error('Found replacer {} with not unique id!'.format(replacer_cls))
    log.debug("All replacers has unique idxs.")
def get_enabled_and_disabled_transforms():
    """
    Read the force-enabled and force-disabled transformation ids from the environment.

    ``MO_ENABLED_TRANSFORMS`` and ``MO_DISABLED_TRANSFORMS`` are comma-separated lists of
    transformation ids. Whitespace around each id is stripped and empty entries are dropped,
    so an unset or empty variable yields an empty list (``''.split(',')`` would give ``['']``).

    :return: tuple of lists with force enabled and disabled id of transformations.
    """
    def _parse_ids(variable_name):
        # os.environ values are always str, so no type check is needed here.
        raw_value = os.environ.get(variable_name, '')
        return [item.strip() for item in raw_value.split(',') if item.strip()]

    return _parse_ids('MO_ENABLED_TRANSFORMS'), _parse_ids('MO_DISABLED_TRANSFORMS')
67 class ClassType(Enum):
def _update(cls, registered_list: list, registered_dict: dict, key: str, enabled_transforms: list,
            disabled_transforms: list):
    """
    Register the direct subclasses of ``cls`` that are not registered yet and are enabled.

    :param cls: registry base class whose ``__subclasses__()`` are scanned; classes listed in its
        optional ``excluded_classes`` attribute are never registered.
    :param registered_list: classes registered so far; new classes are appended to it in place.
    :param registered_dict: custom name -> class mapping; updated in place with this call's new
        names once all subclasses have been processed.
    :param key: name of the class attribute holding the custom name (update_registration passes 'op').
    :param enabled_transforms: ids whose classes get ``enabled = True``.
    :param disabled_transforms: ids whose classes get ``enabled = False``; applied after
        ``enabled_transforms``, so it wins when an id appears in both.
    :raises Error: if two subclasses processed in this call have custom names equal ignoring case.
    """
    new_keys = {}  # maps a custom name to class
    new_keys_lower = {}  # translates lowered custom name to its original form
    excluded_classes = getattr(cls, 'excluded_classes', ())

    for c in cls.__subclasses__():
        if hasattr(c, 'id'):
            # Force enabling / disabling by id; disabling is applied last, so it takes precedence.
            if c.id in enabled_transforms:
                c.enabled = True
            if c.id in disabled_transforms:
                c.enabled = False

        if c in registered_list or not getattr(c, 'enabled', True):
            log.warning('Skipped {} registration because it was already registered or it was disabled. '.format(c))
            continue
        if c in excluded_classes:
            continue

        registered_list.append(c)
        log.info('New subclass: {}'.format(c))

        # NOTE(review): the original lines deriving `k` were missing from the damaged file; this lookup is
        # reconstructed from how `key` and `k` are used below — confirm against version control.
        k = getattr(c, key, None)
        if k is None:
            continue
        if k.lower() in new_keys_lower:
            raise Error(
                'Attempt to register of custom name {} for the second time as class {}. '
                'Note that custom names are case-insensitive. ' +
                refer_to_faq_msg(55), k, c)
        new_keys_lower[k.lower()] = k
        new_keys[k] = c
        log.info('Registered a new subclass with key: {}'.format(k))
    registered_dict.update(new_keys)
def update_registration(classes: list, enabled_transforms: list, disabled_transforms: list):
    """
    Register the subclasses of every registry class in ``classes``.

    For each registry class, ``_update`` is called with its ``registered_cls`` list, its
    ``registered_ops`` dict and the key attribute ``'op'``. The registry class is then recorded
    in the module-level ``_registered_classes_dict`` under ``cls.class_type()``.

    :param classes: registry base classes to process.
    :param enabled_transforms: transformation ids to force-enable.
    :param disabled_transforms: transformation ids to force-disable.
    """
    for cls in classes:
        _update(cls, cls.registered_cls, cls.registered_ops, 'op', enabled_transforms, disabled_transforms)
        _registered_classes_dict.setdefault(cls.class_type(), set()).add(cls)
116 def apply_replacements(graph: Graph, replacements_type):
118 Apply all patterns that do not have 'op' first, then apply patterns from registered_ops.
119 If two or more classes replaces the same op (both have op class attribute and values match), such
120 pattern is not applied (while registration it will warn user that we have a conflict).
122 dependency_graph = Graph()
123 for class_type, classes_set in _registered_classes_dict.items():
124 if class_type == replacements_type:
125 for cls in classes_set:
126 replacers = [c for c in cls.registered_cls if not hasattr(c, 'op')] + \
127 [c for op, c in cls.registered_ops.items() if c]
128 for replacer_cls in replacers:
129 if replacer_cls in cls.excluded_replacers:
130 # skip infrastructure classes
133 dependency_graph.add_node(replacer_cls)
134 for cls_after in replacer_cls().run_before():
135 log.debug("Replacer {} will be run before {}".format(replacer_cls, cls_after))
136 dependency_graph.add_edge(replacer_cls, cls_after)
137 for cls_before in replacer_cls().run_after():
138 log.debug("Replacer {} will be run after {}".format(replacer_cls, cls_before))
139 dependency_graph.add_edge(cls_before, replacer_cls)
142 replacers_order = list(nx.topological_sort(dependency_graph))
143 except nx.NetworkXUnfeasible as exception:
144 cycles = nx.simple_cycles(dependency_graph)
145 raise Error('There is(are) cyclic dependency(ies) between replacers. One of the cycles is the following: {}',
146 ' -> '.join([str(node) for node in list(cycles)[0]])) from exception
148 for replacer_cls in replacers_order:
149 replacer = replacer_cls()
151 replacement_id = 'REPLACEMENT_ID'
152 if hasattr(replacer, 'replacement_id'):
153 replacement_id = replacer.replacement_id
155 if hasattr(replacer, 'enabled') and not replacer.enabled:
156 log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
159 if hasattr(replacer, 'graph_condition') and \
160 not all([condition(graph) for condition in replacer.graph_condition]):
161 log.info("Skip replacer {} (graph_condition not satisfied)".format(replacer_cls))
164 log.debug("Run replacer {}".format(replacer_cls))
167 replacer.find_and_replace_pattern(graph)
169 if hasattr(replacer, 'force_clean_up') and replacer.force_clean_up:
170 for_graph_and_each_sub_graph_recursively(
172 graph_clean_up_tf if graph.graph['fw'] == 'tf' else
173 graph_clean_up_onnx if graph.graph['fw'] == 'onnx' else
176 for_graph_and_each_sub_graph_recursively(graph, lambda _: graph.check_empty_graph(replacer_cls))
177 for_graph_and_each_sub_graph_recursively(graph, lambda _: graph.check_shapes_consistency())
180 raise Error('Exception occurred during running replacer "{}" ({}): {}'.format(
183 str(err).replace('[REPLACEMENT_ID]', replacement_id),
185 except Exception as err:
186 raise Exception('Exception occurred during running replacer "{} ({})": {}'.format(
189 str(err).replace('[REPLACEMENT_ID]', replacement_id),