Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / register_custom_ops.py
1 """
2  Copyright (c) 2017-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 from collections import defaultdict
19
20 from mo.front.extractor import FrontExtractorOp
21 from mo.ops.op import Op
22 from mo.utils.error import Error
23 from mo.utils.utils import refer_to_faq_msg
24
25
26 def extension_extractor(node, ex_cls, disable_omitting_optional: bool = False,
27                         enable_flattening_optional_params: bool = False):
28     ex = ex_cls()
29     supported = ex.extract(node)
30     return node.graph.node[node.id] if supported else None
31
32
33 def extension_op_extractor(node, op_cls):
34     op_cls.update_node_stat(node)
35     # TODO Need to differentiate truly supported ops extractors and ops extractors generated here
36     return node.graph.node[node.id]
37
38
39 def find_case_insensitive_duplicates(extractors_collection: dict):
40     """
41     Searches for case-insensitive duplicates among extractors_collection keys.
42     Returns a list of groups, where each group is a list of case-insensitive duplicates.
43     Also returns a dictionary with lowered keys.
44     """
45     keys = defaultdict(list)
46     for k in extractors_collection.keys():
47         keys[k.lower()].append(k)
48     return [duplicates for duplicates in keys.values() if len(duplicates) > 1], keys
49
50
51 def check_for_duplicates(extractors_collection: dict):
52     """
53     Check if extractors_collection has case-insensitive duplicates, if it does,
54     raise exception with information about duplicates
55     """
56     # Check if extractors_collection is a normal form, that is it doesn't have case-insensitive duplicates
57     duplicates, keys = find_case_insensitive_duplicates(extractors_collection)
58     if len(duplicates) > 0:
59         raise Error('Extractors collection have case insensitive duplicates {}. ' +
60                     refer_to_faq_msg(47), duplicates)
61     return {k: v[0] for k, v in keys.items()}
62
63
64 def add_or_override_extractor(extractors: dict, keys: dict, name, extractor, extractor_desc):
65     name_lower = name.lower()
66     if name_lower in keys:
67         old_name = keys[name_lower]
68         assert old_name in extractors
69         del extractors[old_name]
70         log.debug('Overridden extractor entry {} by {}.'.format(old_name, extractor_desc))
71         if old_name != name:
72             log.debug('Extractor entry {} was changed to {}.'.format(old_name, name))
73     else:
74         log.debug('Added a new entry {} to extractors with {}.'.format(name, extractor_desc))
75     # keep extractor name in case-sensitive form for better diagnostics for the user
76     # but we will continue processing of extractor keys in case-insensitive way
77     extractors[name] = extractor
78     keys[name_lower] = name
79
80
81 def update_extractors_with_extensions(extractors_collection: dict = None,
82                                       disable_omitting_optional: bool = False,
83                                       enable_flattening_optional_params: bool = False):
84     """
85     Update tf_op_extractors based on mnemonics registered in Op and FrontExtractorOp.
86     FrontExtractorOp extends and overrides default extractors.
87     Op extends but doesn't override extractors.
88     """
89     keys = check_for_duplicates(extractors_collection)
90     for op, ex_cls in FrontExtractorOp.registered_ops.items():
91         add_or_override_extractor(
92             extractors_collection,
93             keys,
94             op,
95             lambda node, cls=ex_cls: extension_extractor(
96                 node, cls, disable_omitting_optional, enable_flattening_optional_params),
97             'custom extractor class {}'.format(ex_cls)
98         )
99
100     for op, op_cls in Op.registered_ops.items():
101         op_lower = op.lower()
102         if op_lower not in keys:
103             extractors_collection[op] = (lambda c: lambda node: extension_op_extractor(node, c))(op_cls)
104             log.debug('Added a new entry {} to extractors with custom op class {}.'.format(op, op_cls))
105             keys[op_lower] = op
106     check_for_duplicates(extractors_collection)