Publishing 2019 R3 content
[platform/upstream/dldt.git] / tools / benchmark / utils / inputs_filling.py
1 """
2  Copyright (C) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import os
18 import cv2
19 import numpy as np
20
21 from glob import glob
22
23 from .constants import IMAGE_EXTENSIONS, BINARY_EXTENSIONS
24 from .logging import logger
25
26
def is_image(blob):
    """Tell whether a network input looks like a 3-channel image.

    Args:
        blob: input description exposing ``layout`` and ``shape``.

    Returns:
        True only for an "NCHW" input whose channel dimension (``shape[1]``)
        equals 3; ``shape`` is not inspected for any other layout.
    """
    return blob.layout == "NCHW" and blob.shape[1] == 3
32
33
def is_image_info(blob):
    """Tell whether a network input looks like an "image info" vector.

    Args:
        blob: input description exposing ``layout`` and ``shape``.

    Returns:
        True only for an "NC" input with at least two columns (``shape[1]``);
        ``shape`` is not inspected for any other layout.
    """
    return blob.layout == "NC" and blob.shape[1] >= 2
39
40
def _warn_about_file_count(kind, required, files, extensions):
    """Warn when the number of ``kind`` input files does not match what is needed.

    Args:
        kind: label used in the messages ("image" or "binary").
        required: number of files the infer requests will consume.
        files: the files found for this kind.
        extensions: accepted extensions, listed when no file was found.
    """
    if required > 0 and len(files) == 0:
        logger.warn("No supported {} inputs found! Please check your file extensions: {}".format(
            kind, ",".join(extensions)))
    elif required > len(files):
        logger.warn(
            "Some {} input files will be duplicated: {} files are required, but only {} were provided".format(
                kind, required, len(files)))
    elif required < len(files):
        logger.warn("Some {} input files will be ignored: only {} files are required from {}".format(
            kind, required, len(files)))


def get_inputs(path_to_input, batch_size, input_info, requests):
    """Build the input data for every infer request.

    Each network input is filled, in order of preference:
      1. from image files, if the input is a 3-channel NCHW image and images exist;
      2. from binary files, if any exist;
      3. with the size of the only image input, if it is an "NC" image-info input;
      4. with random values otherwise.

    Args:
        path_to_input: file or directory to take input files from; a falsy
            value means no files are used.
        batch_size: number of samples prepared per input.
        input_info: mapping of input name to an object exposing ``precision``,
            ``layout`` and ``shape``.
        requests: the infer requests; only ``len(requests)`` is used here.

    Returns:
        A list with one ``{input_name: data}`` dict per request.
    """
    keys = list(input_info.keys())

    input_image_sizes = {}
    for key in keys:
        info = input_info[key]
        if is_image(info):
            input_image_sizes[key] = (info.shape[2], info.shape[3])
        logger.info("Network input '{}' precision {}, dimensions ({}): {}".format(
            key, info.precision, info.layout, " ".join(str(x) for x in info.shape)))

    # Position of each image input among image inputs only; together with
    # images_count this decides which image files that input reads.
    image_input_ids = {key: idx for idx, key in enumerate(input_image_sizes)}

    images_count = len(input_image_sizes)
    binaries_count = len(input_info) - images_count

    image_files = []
    binary_files = []
    if path_to_input:
        image_files = sorted(get_files_by_extensions(path_to_input, IMAGE_EXTENSIONS))
        binary_files = sorted(get_files_by_extensions(path_to_input, BINARY_EXTENSIONS))

    if not image_files and not binary_files:
        logger.warn("No input files were given: all inputs will be filled with random values!")
    else:
        _warn_about_file_count("binary", binaries_count * batch_size * len(requests),
                               binary_files, BINARY_EXTENSIONS)
        _warn_about_file_count("image", images_count * batch_size * len(requests),
                               image_files, IMAGE_EXTENSIONS)

    requests_input_data = []
    for request_id in range(len(requests)):
        logger.info("Infer Request {} filling".format(request_id))
        input_data = {}
        for key in keys:
            info = input_info[key]

            # input is image
            if is_image(info) and image_files:
                # Stride over image inputs only (not over all inputs), so files
                # are consumed consecutively and exactly
                # images_count * batch_size * len(requests) of them are used,
                # matching the count reported by the warnings above.
                input_data[key] = fill_blob_with_image(image_files, request_id, batch_size,
                                                       image_input_ids[key], images_count, info.shape)
                continue

            # input is binary
            if binary_files:
                # NOTE(review): fill_blob_with_binary is not defined in the visible
                # part of this file -- confirm it exists before relying on this path.
                input_data[key] = fill_blob_with_binary(binary_files, info.shape)
                continue

            # most likely input is image info: use the size of the only image input
            if is_image_info(info) and len(input_image_sizes) == 1:
                image_size = next(iter(input_image_sizes.values()))
                logger.info("Fill input '" + key + "' with image size " + str(image_size[0]) + "x" +
                            str(image_size[1]))
                input_data[key] = fill_blob_with_image_info(image_size, info.shape)
                continue

            # fill with random data
            logger.info("Fill input '{}' with random values ({} is expected)".format(
                key, "image" if is_image(info) else "some binary data"))
            input_data[key] = fill_blob_with_random(info.precision, info.shape)

        requests_input_data.append(input_data)

    return requests_input_data
127
128
def get_files_by_extensions(path_to_input, extensions):
    """Return the input files whose extension is listed in ``extensions``.

    ``path_to_input`` may be a single file or a directory; a directory is
    scanned non-recursively and sub-directories are skipped. A file's
    extension is upper-cased and stripped of its dot before the membership
    test, so only upper-case, dot-less entries in ``extensions`` can match.
    The filter applies to a single-file path too, so that a file is only
    returned for the extension list it actually belongs to.

    Args:
        path_to_input: path to a file or a directory.
        extensions: accepted extensions, e.g. ``["JPG", "PNG"]``.

    Returns:
        List of matching paths, in ``glob`` order (not sorted).
    """
    def has_wanted_extension(file_path):
        # splitext looks only at the final path component, so dots in
        # directory names are not mistaken for an extension.
        extension = os.path.splitext(file_path)[1][1:].upper()
        return extension in extensions

    if os.path.isfile(path_to_input):
        candidates = [path_to_input]
    else:
        pattern = os.path.join(path_to_input, '*')
        candidates = [entry for entry in glob(pattern) if os.path.isfile(entry)]
    return [entry for entry in candidates if has_wanted_extension(entry)]
141
142
def fill_blob_with_image(image_paths, request_id, batch_size, input_id, input_size, shape):
    """Load ``batch_size`` images into an NCHW array of the given shape.

    The first file used is at index
    ``request_id * batch_size * input_size + input_id`` of ``image_paths``.
    Each further batch element advances by ``input_size``, wrapping around
    the list.

    Args:
        image_paths: candidate image files.
        request_id: index of the infer request being filled.
        batch_size: number of images to load.
        input_id: position of this input among the inputs sharing the files.
        input_size: number of inputs sharing the files (the index stride).
        shape: NCHW shape of the blob; images are resized to ``shape[2:]``.

    Returns:
        ``np.ndarray`` of ``shape`` (default float64) with one CHW image per batch slot.

    Raises:
        Exception: if OpenCV cannot read an image file.
    """
    images = np.ndarray(shape)
    height, width = shape[2], shape[3]
    new_im_size = (height, width)
    image_index = request_id * batch_size * input_size + input_id
    for b in range(batch_size):
        image_index %= len(image_paths)
        image_filename = image_paths[image_index]
        logger.info('Prepare image {}'.format(image_filename))
        image = cv2.imread(image_filename)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise Exception("Failed to read image: {}".format(image_filename))

        # imread returns (rows, cols, channels), i.e. (H, W, C).
        if image.shape[:-1] != new_im_size:
            logger.warn("Image is resized from ({}) to ({})".format(image.shape[:-1], new_im_size))
            # cv2.resize expects dsize as (width, height), not (height, width).
            image = cv2.resize(image, (width, height))

        # HWC (OpenCV) -> CHW (network layout).
        images[b] = image.transpose((2, 0, 1))

        image_index += input_size
    return images
162
163
def fill_blob_with_image_info(image_size, shape):
    """Build an image-info blob of the given shape.

    In every row, column 0 holds ``image_size[0]``. Column 1, when the row
    has one, holds ``image_size[1]``. Every other element is 1.

    Args:
        image_size: sequence whose first two items go into columns 0 and 1.
            ``get_inputs`` passes ``(shape[2], shape[3])`` of the NCHW image
            input.
        shape: shape of the blob; ``shape[1]`` is the row length.

    Returns:
        float64 ``np.ndarray`` of ``shape``.
    """
    im_info = np.ones(shape)
    for row in im_info:
        for col in range(min(2, shape[1])):
            row[col] = image_size[col]
    return im_info
171
172
def fill_blob_with_random(precision, shape):
    """Return an array of ``shape`` filled with random values of type ``precision``.

    Floating-point precisions (FP32, FP16) are drawn uniformly from [0, 1).
    Integer precisions (I32, U8, I8, U16, I16) are drawn uniformly from the
    integers 0..255, capped at the type's maximum (0..127 for I8).

    Args:
        precision: precision name, e.g. "FP32" or "U8".
        shape: dimensions of the array to create.

    Returns:
        ``np.ndarray`` of the requested shape and dtype.

    Raises:
        Exception: if ``precision`` is not supported.
    """
    float_types = {"FP32": np.float32, "FP16": np.float16}
    int_types = {"I32": np.int32, "U8": np.uint8, "I8": np.int8, "U16": np.uint16, "I16": np.int16}

    if precision in float_types:
        # np.random.rand() with no dimensions returns a Python float;
        # np.asarray turns it into a 0-d array so astype still works.
        return np.asarray(np.random.rand(*shape)).astype(float_types[precision])
    if precision in int_types:
        dtype = int_types[precision]
        # Casting np.random.rand() output (all in [0, 1)) to an integer type
        # truncates every value to 0, so draw integers directly instead.
        high = min(255, int(np.iinfo(dtype).max))
        return np.asarray(np.random.randint(0, high + 1, size=tuple(shape))).astype(dtype)
    raise Exception("Input precision is not supported: {}".format(precision))