2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from builtins import AttributeError
18 from defusedxml import ElementTree
20 from mo.front.caffe.collect_attributes import collect_attributes
21 from mo.front.caffe.extractor import node_pb_arg
22 from mo.front.common.register_custom_ops import check_for_duplicates, add_or_override_extractor
def expected_attribs(layer_attrs: dict, attrs: list, fileName: str) -> bool:
    """
    Check that all required attributes are present on one <CustomLayer> element.

    :param layer_attrs: attribute mapping of the element (name -> value), e.g. ``child.attrib``
    :param attrs: names of the attributes that must be present
    :param fileName: path of the mapping file; used only in the error message
    :return: True if every name in ``attrs`` is present in ``layer_attrs``;
             otherwise an error listing the missing names is logged and False is returned
    """
    missing = [attr for attr in attrs if attr not in layer_attrs]
    if not missing:
        return True
    # NativeType may itself be one of the missing attributes, so fall back to a generic description.
    layer = "layer {}".format(layer_attrs['NativeType']) if 'NativeType' in layer_attrs else "one of the layers"
    log.error('Missing required attribute(s) {} for {} in {}. Skipped.'.format(', '.join(missing), layer, fileName))
    return False
def load_layers_xml(fileName: str):
    """
    Parse a custom layers mapping XML file.

    Each child of the root element must be a <CustomLayer> element with
    ``NativeType`` and ``hasParam`` attributes. When ``hasParam`` is 'true'
    (case-insensitive) the element must also carry ``protoParamName``.
    Unexpected tags, duplicated NativeTypes and invalid entries are logged
    and skipped.

    :param fileName: path to the mapping XML file
    :return: dict mapping NativeType -> attribute dict of its <CustomLayer> element;
             an empty dict if the file cannot be parsed
    """
    try:
        xml = ElementTree.parse(fileName).getroot()
    except Exception as e:
        # Best effort: an unreadable or malformed mapping file means "no custom layers".
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        log.debug('Cannot load custom layers mapping from {}: {}'.format(fileName, e))
        return {}

    layers_map = {}
    for child in xml:
        if child.tag != 'CustomLayer':
            log.error('Unexpected "{}" tag in {}. Should be CustomLayer. Skipped.'.format(child.tag, fileName))
            continue
        if not expected_attribs(child.attrib, ['NativeType', 'hasParam'], fileName):
            continue
        layer = child.attrib['NativeType']
        if layer in layers_map:
            log.error('Duplicated layer definition in {} for NativeType = {}. Skipped.'.format(fileName, layer))
            continue
        has_param = child.attrib['hasParam'].lower()
        if has_param == 'false':
            layers_map[layer] = child.attrib
        elif has_param == 'true':
            # expected_attribs already logs the missing protoParamName, so do not
            # additionally report the (valid) hasParam value as unrecognized.
            if expected_attribs(child.attrib, ['protoParamName'], fileName):
                layers_map[layer] = child.attrib
        else:
            log.error(
                'Cannot recognize {} value for hasParam for layer {}. Should be true or false. Skipped.'.format(
                    child.attrib['hasParam'], layer))
    return layers_map
# Native layer parameter names that obfuscate_special_attrs renames before they are
# stored as node attributes; presumably they clash with attributes the Model Optimizer
# itself sets on graph nodes (id, name, type, shape, infer, ...) -- TODO confirm.
special_keys = [
    'id', 'name', 'precision', 'type', 'layer',
    'value', 'shape', 'op', 'kind', 'infer',
]

# Module-wide suffix counter used by new_obfuscated_key to build replacement names;
# it is only ever incremented in this module.
obfuscation_counter = 0
def new_obfuscated_key(attrs: dict, key: str):
    """
    Generate a replacement name for ``key``.

    Candidates have the form ``key + str(N)``, where N is read from the
    module-level ``obfuscation_counter``. The counter is advanced once for
    every candidate tried. The first candidate that is neither a key of
    ``attrs`` nor listed in ``special_keys`` is returned.
    """
    global obfuscation_counter
    candidate = key + str(obfuscation_counter)
    obfuscation_counter += 1
    while candidate in attrs or candidate in special_keys:
        candidate = key + str(obfuscation_counter)
        obfuscation_counter += 1
    return candidate
def obfuscate_attr_key(attrs: dict, key: str, keys: list):
    """
    Rename ``attrs[key]`` to a fresh name when ``key`` is one of ``special_keys``.

    The fresh name comes from new_obfuscated_key, so it matches neither
    ``special_keys`` nor any existing attribute. The value is moved under the
    new name and the old name is removed. The entry ``key`` in ``keys`` is
    replaced by the tuple ``(key, new_name)``. Nothing happens if ``key`` is
    absent from ``attrs`` or is not a special key.

    :param attrs: attribute dict, modified in place
    :param key: attribute name to rename
    :param keys: list of attribute names, modified in place; must contain ``key``
                 when a rename happens (``list.index`` raises ValueError otherwise)
    """
    if key in attrs and key in special_keys:
        replacement = new_obfuscated_key(attrs, key)
        assert replacement not in attrs
        assert replacement not in keys
        # Move the value under the new name; the new entry ends up last in attrs.
        attrs[replacement] = attrs.pop(key)
        keys[keys.index(key)] = (key, replacement)
        log.debug('Obfuscated attribute name {} to {}'.format(key, replacement))
def obfuscate_special_attrs(attrs: dict, keys: list):
    """
    Rename every entry of ``attrs`` whose name is listed in ``special_keys``.

    Both ``attrs`` and ``keys`` are modified in place (see obfuscate_attr_key).
    Special names that do not occur in ``attrs`` are ignored.
    """
    for special_name in special_keys:
        obfuscate_attr_key(attrs, special_name, keys)
def proto_extractor(pb, model_pb, mapping, disable_omitting_optional, enable_flattening_nested_params):
    """
    Extract the attributes of a custom Caffe layer according to its mapping entry.

    :param pb: layer message; ``pb.type``, ``pb.name`` and the field named by
               ``mapping['protoParamName']`` are read
    :param model_pb: not used here; kept for the extractor call signature
    :param mapping: attributes of the layer's <CustomLayer> entry in the mapping XML
    :param disable_omitting_optional: forwarded to collect_attributes
    :param enable_flattening_nested_params: forwarded to collect_attributes
    :return: {} when ``mapping['hasParam']`` is not 'true' (case-insensitive); otherwise
             the collected native attributes (special names renamed, 'mo_caffe' replaced
             by 'caffe' in values) plus an 'IE' entry listing the attribute names to emit
    :raises ValueError: if reading the parameter field or collecting its attributes
                        raises AttributeError (the original error is chained as the cause)
    """
    log.info("Custom extractor for layer {} with mapping {}".format(pb.type, mapping))
    log.debug('Found custom layer {}. Params are processed'.format(pb.name))
    if mapping['hasParam'].lower() != 'true':
        return {}
    try:
        native_attr = collect_attributes(getattr(pb, mapping['protoParamName']),
                                         disable_omitting_optional=disable_omitting_optional,
                                         enable_flattening_nested_params=enable_flattening_nested_params)
    except AttributeError as e:
        # The last token of the AttributeError message is the missing attribute name.
        error_message = 'Layer {} has no attribute {}'.format(pb.type, str(e).split(' ')[-1])
        log.error(error_message)
        raise ValueError(error_message) from e
    keys = list(native_attr.keys())
    # Renamed entries become (original_name, stored_name) tuples in keys.
    obfuscate_special_attrs(native_attr, keys)
    # avoid 'mo_caffe' appearing in param
    for attr, value in native_attr.items():
        # Only values are reassigned (no keys added/removed), so iterating items() is safe.
        if 'mo_caffe' in value:
            native_attr[attr] = value.replace('mo_caffe', 'caffe')
    log.debug(str(keys))
    log.debug(str(native_attr))

    attrs = {
        'IE': [(
            'layer',
            [('id', lambda node: node.id), 'name', 'precision', 'type'],
            [
                ('data', keys, []),
                '@ports',
                '@consts'])]}
    attrs.update(native_attr)
    return attrs
def update_extractors(extractors, layers_map, disable_omitting_optional, enable_flattening_nested_params):
    """
    Register (or override) an extractor for every custom layer in ``layers_map``.

    :param extractors: extractor registry handed to check_for_duplicates and
                       add_or_override_extractor (presumably updated in place -- confirm)
    :param layers_map: NativeType -> <CustomLayer> attributes, presumably the result
                       of load_layers_xml
    :param disable_omitting_optional: forwarded to proto_extractor
    :param enable_flattening_nested_params: forwarded to proto_extractor
    """
    def bind_mapping(mapping):
        # A factory gives each extractor its own ``mapping``; a lambda closing over
        # the loop variable directly would late-bind to the last layer's mapping.
        return node_pb_arg(
            lambda pb, model_pb: proto_extractor(
                pb, model_pb, mapping, disable_omitting_optional, enable_flattening_nested_params
            )
        )

    registered = check_for_duplicates(extractors)
    for layer in layers_map:
        add_or_override_extractor(
            extractors,
            registered,
            layer,
            bind_mapping(layers_map[layer]),
            'custom layer {} from custom layers mapping xml file'.format(layer)
        )
    # Run the duplicate check again on the updated registry.
    check_for_duplicates(extractors)