2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from ..config import ConfigError
21 from ..utils import extract_image_representations
def __init__(self, inputs_config, network_inputs, prepare_input_data=None):
    """Create a feeder that prepares network input batches.

    Args:
        inputs_config: list of input entries from the config; it is parsed and
            validated by ``self.configure``.
        network_inputs: mapping of network input names to their descriptions.
            The description of a layer is passed as ``input_layer`` to the
            transform function.
        prepare_input_data: optional callable ``(data, input_layer) -> array``
            used to convert the collected batch of one layer. When not given,
            the built-in ``fit_to_input`` conversion is used.
    """
    def fit_to_input(data, input_layer):
        # Default conversion. A 4-D batch is transposed with axes (0, 3, 1, 2),
        # presumably NHWC -> NCHW (TODO confirm against the data readers).
        # Any other batch is converted to an ndarray rather than falling
        # through and yielding None. ``input_layer`` is accepted only to match
        # the ``prepare_input_data`` signature.
        if len(np.shape(data)) == 4:
            return np.transpose(data, [0, 3, 1, 2])
        return np.array(data)

    self.input_transform_func = prepare_input_data or fit_to_input
    self.network_inputs = network_inputs
    self.configure(inputs_config)
def configure(self, inputs_config):
    """Parse ``inputs_config`` and store the resulting input layout on the instance.

    ``const_inputs``, ``non_constant_inputs`` and ``inputs_mapping`` are taken
    from ``self._parse_inputs_config`` and assigned before validation happens.

    Raises:
        ConfigError: if no non-constant (data-fed) input is left.
    """
    parsed = self._parse_inputs_config(inputs_config)
    const_layers, data_layers, regex_by_layer = parsed
    self.const_inputs = const_layers
    self.non_constant_inputs = data_layers
    self.inputs_mapping = regex_by_layer
    if self.non_constant_inputs:
        return
    raise ConfigError('Network should contain at least one layer for setting variable data.')
def fill_non_constant_inputs(self, data_representation_batch):
    """Collect one batch per non-constant network input and transform it.

    For every layer in ``self.non_constant_inputs``, one value is taken from
    each element of ``data_representation_batch``:

    * if the element's ``identifier`` is not a list and the layer has no regex
      configured, the element's whole ``data`` is used;
    * otherwise the layer's compiled regex (from ``self.inputs_mapping``) is
      matched against the identifier(s), and the data item paired with the
      first matching identifier is used.

    The resulting ``{layer_name: [values, ...]}`` dict is passed to
    ``self._transform_batch`` together with the second element returned by
    ``extract_image_representations`` (the meta list).

    Raises:
        ConfigError: if a regex is required but none is configured for the
            layer, or if no identifier matches the layer's regex.
    """
    filled_inputs = {}
    for input_layer in self.non_constant_inputs:
        input_regex = self.inputs_mapping[input_layer] if self.inputs_mapping else None
        input_batch = []
        for data_representation in data_representation_batch:
            identifiers = data_representation.identifier
            data = data_representation.data
            if not isinstance(identifiers, list) and not input_regex:
                # Nothing to match against: the whole data entry feeds this layer.
                input_batch.append(data)
                continue

            if not input_regex:
                raise ConfigError('Impossible to choose correct data for layer {}. '
                                  'Please provide regular expression for matching in config.'.format(input_layer))
            # Wrap a scalar identifier and its data into one-element lists so
            # both cases can be zipped the same way.
            if np.isscalar(identifiers):
                identifiers, data = [identifiers], [data]
            input_data = next(
                (data_value for identifier, data_value in zip(identifiers, data) if input_regex.match(identifier)),
                None
            )
            if input_data is None:
                raise ConfigError('Suitable data for filling layer {} not found'.format(input_layer))
            input_batch.append(input_data)

        filled_inputs[input_layer] = input_batch

    return self._transform_batch(filled_inputs, extract_image_representations(data_representation_batch)[1])
def _parse_inputs_config(self, inputs_entry):
    """Split the configured inputs into constant and data-fed (non-constant) ones.

    Args:
        inputs_entry: iterable of dicts with ``'name'``, ``'type'`` and
            ``'value'`` keys.
            - An entry with ``type == 'CONST_INPUT'`` is a constant input. Its
              value is kept as-is, except that lists are converted to
              ``np.ndarray``.
            - Any other type is a data input, and its ``value`` is compiled as
              a regular expression.

    Returns:
        A tuple ``(constant_inputs, non_constant_inputs, mapping)``:
        - ``constant_inputs`` maps names to values.
        - ``non_constant_inputs`` lists the data-fed input names: configured
          ones first, then network inputs missing from the config.
        - ``mapping`` maps configured data inputs to their compiled regex, or
          is ``None`` when no data input was configured.

    Raises:
        ConfigError: if a configured name is not a network input, or if some
            data inputs are configured while other network inputs are not
            mentioned in the config.
    """
    constant_inputs = {}
    non_constant_inputs_mapping = {}
    for input_ in inputs_entry:
        name = input_['name']
        if name not in self.network_inputs:
            raise ConfigError('network does not contain input "{}"'.format(name))
        value = input_['value']

        if input_['type'] == 'CONST_INPUT':
            if isinstance(value, list):
                value = np.array(value)
            constant_inputs[name] = value
        else:
            non_constant_inputs_mapping[name] = re.compile(value)

    non_constant_inputs = list(non_constant_inputs_mapping)
    # Build the "already configured" set once instead of concatenating two
    # lists again for every network input.
    configured = set(non_constant_inputs) | set(constant_inputs)
    not_config_inputs = [
        input_layer for input_layer in self.network_inputs.keys() if input_layer not in configured
    ]
    if non_constant_inputs and not_config_inputs:
        raise ConfigError('input value for {} are not presented in config.'.format(','.join(not_config_inputs)))
    # When no data input is configured explicitly, every unmentioned network
    # input is treated as data-fed.
    non_constant_inputs += not_config_inputs

    return constant_inputs, non_constant_inputs, non_constant_inputs_mapping or None
101 def _transform_batch(self, batch_data, meta):
102 def calculate_num_splits(layers_data, batch_size):
104 for _, data in layers_data.items():
107 total_tiles_num += len(tiles)
109 offset = 0 if total_tiles_num % batch_size == 0 else 1
110 splits_for_layer = (total_tiles_num // batch_size) + offset
111 if max_split_num < splits_for_layer:
112 max_split_num = splits_for_layer
116 def separate_data(data, num_splits):
117 grouped_data = [[] for _ in range(num_splits)]
118 for data_part in data:
119 for split_id, data_split in enumerate(data_part):
120 grouped_data[split_id % num_splits].append(data_split)
123 batch_size = len(meta)
124 if meta[0].get('multi_infer', False):
125 num_splits = calculate_num_splits(batch_data, batch_size)
126 infers_data = [{} for _ in range(num_splits)]
127 for layer_name, layer_data in batch_data.items():
128 batch_for_all_infers = separate_data(layer_data, num_splits)
129 for infer_id, on_infer_batch in enumerate(batch_for_all_infers):
130 infers_data[infer_id][layer_name] = self.input_transform_func(
131 on_infer_batch, self.network_inputs[layer_name]
135 for layer_name, layer_data in batch_data.items():
136 batch_data[layer_name] = self.input_transform_func(layer_data, self.network_inputs[layer_name])