Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / caffe / custom_layers_mapping.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17 from builtins import AttributeError
18 from defusedxml import ElementTree
19
20 from mo.front.caffe.collect_attributes import collect_attributes
21 from mo.front.caffe.extractor import node_pb_arg
22 from mo.front.common.register_custom_ops import check_for_duplicates, add_or_override_extractor
23
24
25 def expected_attribs(layer_attrs: list, attrs: list, fileName: str):
26     missing = [attr for attr in attrs if attr not in layer_attrs]
27     if len(missing):
28         layer = "layer {}".format(layer_attrs['NativeType']) if 'NativeType' in layer_attrs else "one of the layers"
29         log.error('Missing required attribute(s) {} for {} in {}. Skipped.'.format(', '.join(missing), layer, fileName))
30         return False
31     return True
32
33
34 def load_layers_xml(fileName: str):
35     try:
36         xml = ElementTree.parse(fileName).getroot()
37     except:
38         return {}
39
40     layers_map = {}
41     for child in xml:
42         if child.tag == 'CustomLayer':
43             if expected_attribs(child.attrib, ['NativeType', 'hasParam'], fileName):
44                 layer = child.attrib['NativeType']
45                 if layer in layers_map:
46                     log.error('Duplicated layer definition in {} for NativeType = {}. Skipped.'.format(fileName, layer))
47                 else:
48                     has_param = child.attrib['hasParam'].lower()
49                     if has_param == 'true' and expected_attribs(child.attrib, ['protoParamName'],
50                                                                 fileName) or has_param == 'false':
51                         layers_map[layer] = child.attrib
52                     else:
53                         log.error(
54                             'Cannot recognize {} value for hasParam for layer {}. Should be true or false. Skipped.'.format(
55                                 child.attrib['hasParam'], layer))
56
57         else:
58             log.error('Unexpected "{}" tag in {}. Should be CustomLayer. Skipped.'.format(child.tag, fileName))
59     return layers_map
60
61
62 special_keys = ['id', 'name', 'precision', 'type', 'layer', 'value', 'shape', 'op', 'kind', 'infer']
63
64 obfuscation_counter = 0
65
66
67 def new_obfuscated_key(attrs: dict, key: str):
68     global obfuscation_counter
69     while True:
70         new_key = key + str(obfuscation_counter)
71         obfuscation_counter += 1
72         if new_key not in attrs and new_key not in special_keys:
73             return new_key
74
75
76 def obfuscate_attr_key(attrs: dict, key: str, keys: list):
77     """
78     Replace attribute with key by another key that is not in
79     special_keys list and do not match other attributes.
80     """
81     if key not in attrs or key not in special_keys:
82         return
83
84     new_key = new_obfuscated_key(attrs, key)
85     assert new_key not in attrs
86     assert new_key not in keys
87     attrs[new_key] = attrs[key]
88     del attrs[key]
89     key_index = keys.index(key)
90     keys[key_index] = (key, new_key)
91     log.debug('Obfuscated attribute name {} to {}'.format(key, new_key))
92
93
94 def obfuscate_special_attrs(attrs: dict, keys: list):
95     for key in special_keys:
96         obfuscate_attr_key(attrs, key, keys)
97
98
99 def proto_extractor(pb, model_pb, mapping, disable_omitting_optional, enable_flattening_nested_params):
100     log.info("Custom extractor for layer {} with mapping {}".format(pb.type, mapping))
101     log.debug('Found custom layer {}. Params are processed'.format(pb.name))
102     if mapping['hasParam'].lower() != 'true':
103         return {}
104     try:
105         native_attr = collect_attributes(getattr(pb, mapping['protoParamName']),
106                                          disable_omitting_optional=disable_omitting_optional,
107                                          enable_flattening_nested_params=enable_flattening_nested_params)
108     except AttributeError as e:
109         error_message = 'Layer {} has no attribute {}'.format(pb.type, str(e).split(' ')[-1])
110         log.error(error_message)
111         raise ValueError(error_message)
112     keys = list(native_attr.keys())
113     obfuscate_special_attrs(native_attr, keys)
114     # avoid 'mo_caffe' appearing in param
115     for attr in native_attr:
116         if 'mo_caffe' in native_attr[attr]:
117             native_attr[attr] = native_attr[attr].replace('mo_caffe', 'caffe')
118     log.debug(str(keys))
119     log.debug(str(native_attr))
120
121     attrs = {
122         'IE': [(
123             'layer',
124             [('id', lambda node: node.id), 'name', 'precision', 'type'],
125             [
126                 ('data', keys, []),
127                 '@ports',
128                 '@consts'])]}
129     attrs.update(native_attr)
130     return attrs
131
132
133 def update_extractors(extractors, layers_map, disable_omitting_optional, enable_flattening_nested_params):
134     keys = check_for_duplicates(extractors)
135     for layer, attrs in layers_map.items():
136         add_or_override_extractor(
137             extractors,
138             keys,
139             layer,
140             (
141                 lambda l: node_pb_arg(
142                     lambda pb, model_pb: proto_extractor(
143                         pb, model_pb, l, disable_omitting_optional, enable_flattening_nested_params
144                     )
145                 )
146             )(layers_map[layer]),
147             'custom layer {} from custom layers mapping xml file'.format(layer)
148         )
149     check_for_duplicates(extractors)