"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
def cast_to_string(descriptor, value):
    """
    Render a scalar protobuf field value as a string.

    Boolean fields are written as '1' / '0'; every other field type goes through str().

    Args:
        descriptor: protobuf FieldDescriptor of the field that holds value
        value: the field value

    Returns:
        string representation of value
    """
    is_bool = descriptor.type == descriptor.TYPE_BOOL
    return str(int(value)) if is_bool else str(value)
def append_unique(attrs, new_attr, value):
    """
    Build a one-entry dictionary {new_attr: value} meant to be merged into attrs.

    attrs itself is not modified. If new_attr is already a key of attrs, an error is logged
    because merging the returned dictionary will overwrite the earlier value.

    Args:
        attrs: dictionary of attributes collected so far (only used for the duplicate check)
        new_attr: attribute name to add
        value: attribute value

    Returns:
        dict with the single entry new_attr -> value
    """
    collides = new_attr in attrs
    if collides:
        # NOTE(review): the text attributes the collision to flattening, but unrolled_name
        # yields bare field names (and hence collisions) when flattening is *disabled*.
        log.error(('The parameter {} overwrites already existing value. '
                   'This happens due to flattening nested parameters. '
                   'Use enable_flattening_nested_params to flatten nesting').format(new_attr))
    return {new_attr: value}
def append_unique_enum(attrs: dict, descriptor, value):
    """
    Build a one-entry dictionary mapping an enum field to the qualified name of its value.

    For an enum whose full name is 'X.Y.Z', the enum's own name 'Z' is dropped and the name of
    the selected value is appended, giving 'X.Y.<VALUE_NAME>'.

    Args:
        attrs: dictionary of attributes collected so far (used for the duplicate check)
        descriptor: protobuf FieldDescriptor of an enum field
        value: numeric value stored in the enum field

    Returns:
        dict with the single entry descriptor.name -> qualified enum value name
    """
    enum_type = descriptor.enum_type
    # remove enum name Z from X.Y.Z name
    scope = enum_type.full_name.rsplit('.', 1)[0]
    # Look the value up by its *number*: enum_type.values is ordered by declaration, so indexing
    # it with the number is wrong for enums whose numbers are not 0..n-1 in declaration order.
    value_name = enum_type.values_by_number[value].name
    return append_unique(attrs, descriptor.name, '{}.{}'.format(scope, value_name))
def unrolled_name(descriptor_name: str, enable_flattening_nested_params: bool = False, prefix: str = '') -> str:
    """
    Return the attribute name to use for a protobuf field.

    When flattening is enabled and a non-empty prefix is given, the name is
    '<prefix>__<descriptor_name>'; otherwise the bare field name is returned.

    Args:
        descriptor_name: name of the protobuf field
        enable_flattening_nested_params: whether nested field names are prefixed
        prefix: name of the enclosing field (empty at the top level)

    Returns:
        the attribute name
    """
    use_prefix = enable_flattening_nested_params and prefix
    return '{}__{}'.format(prefix, descriptor_name) if use_prefix else descriptor_name
def collect_optional_attributes(obj, prefix: str = '', disable_omitting_optional: bool = False,
                                enable_flattening_nested_params: bool = False):
    """
    Collect the optional fields of a protobuf message into a dictionary of string attributes.

    An optional field is emitted when it was explicitly set, when it declares a default value,
    or when disable_omitting_optional is True. Optional sub-messages are walked recursively
    with the same flags.

    Args:
        obj: protobuf message to read
        prefix: name of the enclosing field; prepended as '<prefix>__' to attribute names only
            when enable_flattening_nested_params is True
        disable_omitting_optional: emit optional fields even if they are unset and have no default
        enable_flattening_nested_params: name attributes '<prefix>__<field>' instead of '<field>'

    Returns:
        dict mapping attribute name to its string value (enum fields map to the qualified
        enum value name)
    """
    collected = {}
    set_fields = {field.name for field, _ in obj.ListFields()}
    for descriptor in obj.DESCRIPTOR.fields:
        value = getattr(obj, descriptor.name)
        key = unrolled_name(descriptor.name, enable_flattening_nested_params, prefix)
        if descriptor.label != descriptor.LABEL_OPTIONAL:
            continue
        wanted = descriptor.has_default_value or disable_omitting_optional or descriptor.name in set_fields
        if not wanted:
            continue
        if descriptor.type == descriptor.TYPE_MESSAGE:
            nested = collect_optional_attributes(
                value,
                prefix=key,
                disable_omitting_optional=disable_omitting_optional,
                enable_flattening_nested_params=enable_flattening_nested_params)
            collected.update(nested)
        elif descriptor.type == descriptor.TYPE_ENUM:
            # NOTE(review): append_unique_enum keys the entry by descriptor.name, so enum fields
            # are never prefixed even when flattening is enabled -- confirm this is intended.
            collected.update(append_unique_enum(collected, descriptor, value))
        else:
            collected.update(append_unique(collected, key, cast_to_string(descriptor, value)))
    return collected
def collect_attributes(obj, prefix: str = '', disable_omitting_optional: bool = False,
                       enable_flattening_nested_params: bool = False):
    """
    Collect all fields of a protobuf message into a dictionary of string attributes.

    Optional fields are gathered by collect_optional_attributes. Repeated and required fields
    are then added:
      * repeated scalar fields are joined into one comma-separated string;
      * repeated message fields are walked recursively, one call per element;
      * required scalar fields are converted with cast_to_string;
      * required message fields are walked recursively.
    Repeated fields that are not present in obj.ListFields() are skipped with a warning.

    Args:
        obj: protobuf message to read
        prefix: name of the enclosing field; prepended as '<prefix>__' to attribute names only
            when enable_flattening_nested_params is True
        disable_omitting_optional: emit optional fields even if they are unset and have no default
        enable_flattening_nested_params: name attributes '<prefix>__<field>' instead of '<field>'

    Returns:
        dict mapping attribute name to its string value
    """
    attrs = collect_optional_attributes(obj, prefix, disable_omitting_optional, enable_flattening_nested_params)
    set_fields = {field.name for field, _ in obj.ListFields()}

    def collect_nested(message, nested_prefix):
        # Propagate the caller's flags, as collect_optional_attributes does for optional
        # sub-messages; previously nested messages silently fell back to the defaults.
        return collect_attributes(message, prefix=nested_prefix,
                                  disable_omitting_optional=disable_omitting_optional,
                                  enable_flattening_nested_params=enable_flattening_nested_params)

    for descriptor in obj.DESCRIPTOR.fields:
        value = getattr(obj, descriptor.name)
        name = unrolled_name(descriptor.name, enable_flattening_nested_params, prefix)
        if descriptor.label == descriptor.LABEL_REPEATED:
            if descriptor.name not in set_fields:
                log.warning('Field {} was ignored'.format(descriptor.name))
                continue
            if descriptor.type == descriptor.TYPE_MESSAGE:
                for item in value:
                    attrs.update(collect_nested(item, name))
            else:
                attrs.update(append_unique(attrs, name, ",".join(str(v) for v in value)))
        elif descriptor.label == descriptor.LABEL_REQUIRED:
            if descriptor.type == descriptor.TYPE_MESSAGE:
                # A required (non-repeated) message field holds a single message, not a
                # sequence, so recurse into it directly instead of iterating over it.
                attrs.update(collect_nested(value, name))
            else:
                attrs.update(append_unique(attrs, name, cast_to_string(descriptor, value)))
    return attrs
def merge_attrs(param, update_attrs: dict):
    """
    Keep only the entries of update_attrs whose keys are attribute names of param.

    Args:
        param: protobuf message; its attribute names, as produced by collect_attributes with
            default flags, define which keys are kept
        update_attrs: candidate attribute values

    Returns:
        dict with the entries of update_attrs whose key is one of param's attribute names,
        in update_attrs' iteration order
    """
    known_attrs = collect_attributes(param)
    # Filter update_attrs in its own order instead of iterating a set intersection: set
    # iteration order of strings depends on hash randomization and varied between runs.
    return {key: value for key, value in update_attrs.items() if key in known_attrs}