Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / caffe / collect_attributes.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17
18
19 def cast_to_string(descriptor, value):
20     if descriptor.type != descriptor.TYPE_BOOL:
21         return str(value)
22     return str(int(value))
23
24
25 def append_unique(attrs, new_attr, value):
26     if new_attr in attrs:
27         log.error('The parameter {} overwrites already existing value. '.format(new_attr) +
28                   'This happens due to flattening nested parameters. ' +
29                   'Use enable_flattening_nested_params to flatten nesting')
30     return {new_attr: value}
31
32
33 def append_unique_enum(attrs: dict, descriptor, value):
34     enum_name = '{}.{}'.format(
35         descriptor.enum_type.full_name.rsplit('.', 1)[0],  # remove enum name Z from X.Y.Z name
36         descriptor.enum_type.values[value].name)
37     return append_unique(attrs, descriptor.name, str(enum_name))
38
39
40 def unrolled_name(descriptor_name: str, enable_flattening_nested_params: bool = False, prefix: str = '') -> str:
41     if not enable_flattening_nested_params:
42         return descriptor_name
43     elif prefix:
44         return '{}__{}'.format(prefix, descriptor_name)
45     return descriptor_name
46
47
48 def collect_optional_attributes(obj, prefix: str = '', disable_omitting_optional: bool = False,
49                                 enable_flattening_nested_params: bool = False):
50     """
51     Collect all optional attributes from protobuf message
52     Args:
53         attrs: dictionary with attributes
54         obj: protobuf message
55         prefix: prefix for this protobuf.message
56         disable_omitting_optional: disable omitting optional flag
57         enable_flattening_nested_params: disable flattening optional params flag
58     """
59     attrs = {}
60     fields = [field[0].name for field in obj.ListFields()]
61     for descriptor in obj.DESCRIPTOR.fields:
62         value = getattr(obj, descriptor.name)
63         name = unrolled_name(descriptor.name, enable_flattening_nested_params, prefix)
64         if descriptor.label != descriptor.LABEL_OPTIONAL:
65             continue
66         if (descriptor.has_default_value or disable_omitting_optional) or descriptor.name in fields:
67             if descriptor.type == descriptor.TYPE_MESSAGE:
68                 attrs.update(collect_optional_attributes(value,
69                                                          prefix=name,
70                                                          disable_omitting_optional=disable_omitting_optional,
71                                                          enable_flattening_nested_params=enable_flattening_nested_params))
72             elif descriptor.type == descriptor.TYPE_ENUM:
73                 attrs.update(append_unique_enum(attrs, descriptor, value))
74             else:
75                 attrs.update(append_unique(attrs, name, cast_to_string(descriptor, value)))
76     return attrs
77
78
79 def collect_attributes(obj, prefix: str = '', disable_omitting_optional: bool = False,
80                        enable_flattening_nested_params: bool = False):
81     """
82     Collect all attributes from protobuf message
83     Args:
84         attrs: dictionary with attributes
85         obj: protobuf message
86         prefix: prefix for this protobuf.message
87         disable_omitting_optional: disable omitting optional flag
88         enable_flattening_nested_params: disable flattening optional params flag
89     """
90     attrs = collect_optional_attributes(obj, prefix, disable_omitting_optional, enable_flattening_nested_params)
91     fields = [field[0].name for field in obj.ListFields()]
92     for descriptor in obj.DESCRIPTOR.fields:
93         value = getattr(obj, descriptor.name)
94         name = unrolled_name(descriptor.name, enable_flattening_nested_params, prefix)
95         if descriptor.label == descriptor.LABEL_REPEATED:
96             if descriptor.name not in fields:
97                 log.warning('Field {} was ignored'.format(descriptor.name))
98                 continue
99             if descriptor.type == descriptor.TYPE_MESSAGE:
100                 for x in value:
101                     attrs.update(collect_attributes(x, prefix=name))
102             else:
103                 attrs.update(append_unique(attrs, name, ",".join([str(v) for v in value])))
104         elif descriptor.label == descriptor.LABEL_REQUIRED:
105             if descriptor.type == descriptor.TYPE_MESSAGE:
106                 for x in value:
107                     attrs.update(collect_attributes(x, prefix=name))
108             else:
109                 attrs.update(append_unique(attrs, name, cast_to_string(descriptor, value)))
110     return attrs
111
112
113 def merge_attrs(param, update_attrs: dict):
114     all_attrs = collect_attributes(param)
115     mandatory_attrs = set(all_attrs.keys()).intersection(set(update_attrs.keys()))
116     return {value: update_attrs[value] for value in mandatory_attrs}