2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
def dim_to_shape(dim):
    """
    Converts a proto message holding shape dimensions into a numpy shape array.

    Args:
        dim: proto message with shape dimensions (anything np.array accepts)

    Returns:
        shape of the layer as an np.array of dtype int64
    """
    shape = np.array(dim, dtype=np.int64)
    return shape
def embed_input(attrs: dict, port: int, name: str, value: np.ndarray, bin_name: str = None):
    """
    Appends port information to the given set of attributes of the current layer.
    Mutates passed attributes in place.

    Args:
        attrs: dictionary of existing attributes
        port: relative number of the port for the layer
        name: name of the input; must not already be a key of attrs
        value: array-like of values; stored as a fresh np.array under attrs[name]
        bin_name: optional, representing the specific behavior of the blob,
            either 'weights' or 'biases'; defaults to name when not given

    Returns:
        None. The tuple (port, name, {'bin': bin_name}) is appended to the list
        stored under attrs['embedded_inputs'] (created if missing).
    """
    assert name not in attrs
    attrs[name] = np.array(value)

    # Without an explicit bin name the blob is tagged with its own input name, so callers
    # passing name='weights' / 'biases' get the matching bin label instead of None.
    if not bin_name:
        bin_name = name
    # (input index, input name, future edge attributes)
    input_val = (port, name, {'bin': bin_name})
    attrs.setdefault('embedded_inputs', []).append(input_val)  # pylint: disable=not-callable
def weights_biases(bias_term: bool, model_layer, start_index: int = 1, proto=None):
    """
    Creates object with configured inputs in the following order: 0: weights, 1: biases

    Args:
        bias_term: flag to whether include biases in the final input or not
        model_layer: caffemodel layer containing values in blobs; when falsy, weights
            are synthesized from proto instead
        start_index: port number assigned to the weights; biases use start_index + 1
        proto: layer parameter consulted only when model_layer is falsy; only a
            "diagonal" weight_filler is supported. None (or the legacy {} default)
            means no proto is available.

    Returns:
        dictionary with set up inputs or empty dictionary
    """
    attrs = {}
    if not model_layer:
        # getattr() covers both None and the legacy `proto={}` value (a dict has no
        # weight_filler attribute), avoiding a shared mutable default argument.
        filler = getattr(proto, 'weight_filler', None)
        if filler and filler.type == "diagonal":
            # kernel_size[0] is used for both spatial dims (presumably a square kernel)
            data_len = proto.kernel_size[0] * proto.kernel_size[0] * proto.num_output
            # Flattened data_len x data_len matrix with diag_val on the main diagonal:
            # element (i, i) lives at flat index i * (data_len + 1).
            data = np.zeros(data_len * data_len, dtype=np.float32)
            for i in range(data_len):
                data[i * (data_len + 1)] = filler.diag_val[i]
            embed_input(attrs, start_index, 'weights', data)
            if bias_term:
                embed_input(attrs, start_index + 1, 'biases', np.zeros(proto.num_output, np.float32))
        return attrs

    blobs = model_layer.blobs
    embed_input(attrs, start_index, 'weights', blobs[0].data)
    if bias_term:
        embed_input(attrs, start_index + 1, 'biases', blobs[1].data)
    return attrs
def get_list_from_container(param, prop: str, t):
    """
    Takes proto parameter and extracts a value it stores.

    Args:
        param: proto parameter
        prop: name of the property to take
        t: type of the value (int, float etc.) - only primitive ones

    Returns:
        If it is a container, returns a new list with its values.
        If it is a single value of the given type - a list of single value.
        If neither, the property does not exist for param, or its value is
        falsy (e.g. 0 or an empty container) - empty list.
    """
    if not param or not hasattr(param, prop):
        return []

    prop_val = getattr(param, prop)
    if not prop_val:
        return []
    if isinstance(prop_val, t):
        return [prop_val]
    try:
        # Copy into a plain list so callers always get the documented type rather
        # than the underlying container object.
        return list(prop_val)
    except TypeError:
        # Truthy scalar of a type other than t: neither a container nor a t.
        return []
def get_spatial_attr(default: list, single_name: str, name: str, param):
    """
    Resolves a two-axis spatial attribute of a proto parameter.

    The per-axis fields '<name>_w' / '<name>_h' are used when present and set to
    something other than the corresponding default and other than 0. If an axis is
    still falsy, or both axes equal default[0], the shared field `single_name`
    (scalar or container) is consulted; when it is non-empty and differs from
    `default`, its first value is used for both axes.

    Args:
        default: default values, default[0] for the width axis, default[1] for the height axis
        single_name: name of the field holding a value shared by both axes
        name: prefix of the per-axis fields '<name>_h' and '<name>_w'
        param: proto parameter to read from

    Returns:
        tuple (attr_w, attr_h); each falls back to its default when nothing overrides it
    """
    def _explicit(suffix, fallback):
        # Value of '<name>_<suffix>' if present and neither the fallback nor 0.
        field = '{}_{}'.format(name, suffix)
        if hasattr(param, field):
            value = getattr(param, field)
            if value != fallback and value != 0:
                return value
        return fallback

    # Start from the defaults so both names are always bound.
    attr_h = _explicit('h', default[1])
    attr_w = _explicit('w', default[0])
    if (not attr_h or not attr_w) or (attr_h == attr_w == default[0]):
        shared = get_list_from_container(param, single_name, int)
        if len(shared) > 0 and shared != default:
            attr_w = attr_h = shared[0]
    return attr_w, attr_h
def merge_attrs(all_attrs: dict, update_attrs: dict):
    """
    Selects the entries of update_attrs whose keys are also present in all_attrs.

    Despite the name, no value from all_attrs is copied into the result; all_attrs
    only defines the set of accepted keys.

    Args:
        all_attrs: dictionary whose keys define which attributes are accepted
        update_attrs: dictionary providing the values

    Returns:
        new dictionary {key: update_attrs[key]} for every key present in both inputs,
        ordered as the keys appear in all_attrs
    """
    # Iterating all_attrs (rather than a set intersection) makes the output order
    # deterministic instead of depending on set/hash iteration order.
    return {key: update_attrs[key] for key in all_attrs if key in update_attrs}
def get_canonical_axis_index(shape, axis):
    """
    Normalizes a possibly negative axis index against the rank of shape.

    Args:
        shape: sized object whose length is the rank
        axis: axis index; negative values count from the end

    Returns:
        len(shape) + axis for a negative axis, otherwise axis unchanged.
        No range check is performed.
    """
    if axis < 0:
        return len(shape) + axis
    return axis