"""
 Copyright (c) 2017-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log
from collections import defaultdict

from mo.front.extractor import FrontExtractorOp
from mo.ops.op import Op
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
def extension_extractor(node, ex_cls, disable_omitting_optional: bool = False,
                        enable_flattening_optional_params: bool = False):
    """
    Run a custom front extractor class on a node.

    :param node: node to run the extractor on; its attribute dictionary is
        looked up as ``node.graph.node[node.id]``.
    :param ex_cls: extractor class; it is instantiated without arguments and
        its ``extract(node)`` method is called.
    :param disable_omitting_optional: accepted for interface compatibility;
        not used by this function.
    :param enable_flattening_optional_params: accepted for interface
        compatibility; not used by this function.
    :return: ``node.graph.node[node.id]`` if ``extract`` returned a truthy
        value, otherwise ``None``.
    """
    # The extractor has to exist before ``extract`` can be called on it.
    extractor = ex_cls()
    supported = extractor.extract(node)
    return node.graph.node[node.id] if supported else None
def extension_op_extractor(node, op_cls):
    """
    Populate a node's attributes from a registered Op class.

    Delegates to ``op_cls.update_node_stat(node)``, then returns the node's
    attribute dictionary from the graph.

    :param node: node to update.
    :param op_cls: Op class that provides ``update_node_stat``.
    :return: ``node.graph.node[node.id]``.
    """
    # TODO Need to differentiate truly supported ops extractors and ops extractors generated here
    op_cls.update_node_stat(node)
    node_attrs = node.graph.node[node.id]
    return node_attrs
def find_case_insensitive_duplicates(extractors_collection: dict):
    """
    Group the keys of ``extractors_collection`` by their lower-cased form.

    :param extractors_collection: mapping whose (string) keys are inspected.
    :return: a pair ``(groups, lowered)`` where ``groups`` is a list of the
        key groups that collide case-insensitively (each group is a list of
        two or more original keys), and ``lowered`` is a ``defaultdict(list)``
        mapping every lower-cased key to the list of original keys with that
        lower-cased form, in iteration order.
    """
    lowered = defaultdict(list)
    for original_key in extractors_collection:
        lowered[original_key.lower()].append(original_key)

    groups = []
    for same_key_group in lowered.values():
        if len(same_key_group) > 1:
            groups.append(same_key_group)
    return groups, lowered
def check_for_duplicates(extractors_collection: dict):
    """
    Ensure ``extractors_collection`` has no case-insensitive duplicate keys.

    :param extractors_collection: mapping to validate.
    :return: dict mapping each lower-cased key to the original key.
    :raises Error: if two or more keys differ only by case; the message lists
        the colliding groups.
    """
    # A "normal form" collection has exactly one key per lower-cased name.
    groups, lowered = find_case_insensitive_duplicates(extractors_collection)
    if groups:
        raise Error('Extractors collection have case insensitive duplicates {}. ' +
                    refer_to_faq_msg(47), groups)
    return {lower_key: originals[0] for lower_key, originals in lowered.items()}
def add_or_override_extractor(extractors: dict, keys: dict, name, extractor, extractor_desc):
    """
    Register ``extractor`` under ``name``, replacing any entry whose name
    matches case-insensitively.

    :param extractors: mapping from case-sensitive extractor name to
        extractor; updated in place.
    :param keys: mapping from lower-cased name to the case-sensitive name
        currently used in ``extractors``; updated in place.
    :param name: name to register the extractor under.
    :param extractor: the extractor to store.
    :param extractor_desc: human-readable description, used only in debug logs.
    """
    name_lower = name.lower()
    if name_lower in keys:
        old_name = keys[name_lower]
        # ``keys`` and ``extractors`` are expected to be kept in sync.
        assert old_name in extractors
        del extractors[old_name]
        log.debug('Overridden extractor entry {} by {}.'.format(old_name, extractor_desc))
        # Only report a rename when the spelling actually differs.
        if old_name != name:
            log.debug('Extractor entry {} was changed to {}.'.format(old_name, name))
    else:
        log.debug('Added a new entry {} to extractors with {}.'.format(name, extractor_desc))
    # keep extractor name in case-sensitive form for better diagnostics for the user
    # but we will continue processing of extractor keys in case-insensitive way
    extractors[name] = extractor
    keys[name_lower] = name
def update_extractors_with_extensions(extractors_collection: dict = None,
                                      disable_omitting_optional: bool = False,
                                      enable_flattening_optional_params: bool = False):
    """
    Update an extractors collection in place with the mnemonics registered in
    Op and FrontExtractorOp.

    FrontExtractorOp extends and overrides default extractors.
    Op extends but doesn't override extractors.

    :param extractors_collection: mapping of op name to extractor, modified
        in place. NOTE(review): the ``None`` default cannot actually be used,
        because ``check_for_duplicates`` iterates the collection's keys.
    :param disable_omitting_optional: forwarded to ``extension_extractor``
        for every FrontExtractorOp-based entry.
    :param enable_flattening_optional_params: forwarded to
        ``extension_extractor`` for every FrontExtractorOp-based entry.
    :raises Error: if the collection contains case-insensitive duplicate
        names before or after the update.
    """
    keys = check_for_duplicates(extractors_collection)
    for op, ex_cls in FrontExtractorOp.registered_ops.items():
        add_or_override_extractor(
            extractors_collection,
            keys,
            op,
            # ``cls=ex_cls`` binds the current class now; a plain closure would
            # see only the last loop value.
            lambda node, cls=ex_cls: extension_extractor(
                node, cls, disable_omitting_optional, enable_flattening_optional_params),
            'custom extractor class {}'.format(ex_cls)
        )

    for op, op_cls in Op.registered_ops.items():
        op_lower = op.lower()
        if op_lower not in keys:
            # The outer lambda is called immediately so each entry captures its own op_cls.
            extractors_collection[op] = (lambda c: lambda node: extension_op_extractor(node, c))(op_cls)
            log.debug('Added a new entry {} to extractors with custom op class {}.'.format(op, op_cls))

    # ``keys`` is not updated in the Op loop, so re-validate the final collection
    # to catch Op names that differ from each other only by case.
    check_for_duplicates(extractors_collection)