2 Copyright (C) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
23 from .constants import IMAGE_EXTENSIONS, BINARY_EXTENSIONS
24 from .logging import logger
def is_image(blob):
    """Return True when *blob* looks like an image input.

    An image input has NCHW layout with exactly 3 channels (shape[1]).
    *blob* only needs ``layout`` and ``shape`` attributes.
    """
    if blob.layout != "NCHW":
        return False
    channels = blob.shape[1]
    return channels == 3
def is_image_info(blob):
    """Return True when *blob* looks like an image-info input.

    An image-info input has NC layout with at least 2 channels (shape[1]),
    enough room for the image height and width.
    *blob* only needs ``layout`` and ``shape`` attributes.
    """
    if blob.layout != "NC":
        return False
    channels = blob.shape[1]
    return channels >= 2
def _warn_on_file_count_mismatch(kind, required, files, extensions):
    """Log a warning when the number of found *kind* files differs from *required*.

    kind: word inserted into the log messages ("binary" or "image").
    required: how many files of this kind all requests together would consume.
    files: the files of this kind that were found.
    extensions: the searched extensions, listed when no file was found at all.
    """
    if required > 0 and len(files) == 0:
        logger.warning("No supported {} inputs found! Please check your file extensions: {}".format(
            kind, ",".join(extensions)))
    elif required > len(files):
        logger.warning(
            "Some {} input files will be duplicated: {} files are required, but only {} were provided".format(
                kind, required, len(files)))
    elif required < len(files):
        logger.warning(
            "Some {} input files will be ignored: only {} files are required from {}".format(
                kind, required, len(files)))


def get_inputs(path_to_input, batch_size, input_info, requests):
    """Build the input data for every infer request.

    path_to_input: an input file or directory; when falsy (or when it yields no
        supported files) every input is filled with random values.
    batch_size: number of files consumed per image input per request.
    input_info: mapping from input name to an object exposing ``precision``,
        ``layout`` and ``shape`` (presumably the network's input descriptors -- confirm).
    requests: the infer requests; only their count is used here.

    Returns a list with one dict per request, mapping each input name to the
    data produced by the matching ``fill_blob_with_*`` helper.
    """
    # Remember (H, W) of every image input and log a description of every input.
    input_image_sizes = {}
    for key in input_info.keys():
        info = input_info[key]
        if is_image(info):
            input_image_sizes[key] = (info.shape[2], info.shape[3])
        logger.info("Network input '{}' precision {}, dimensions ({}): {}".format(
            key, info.precision, info.layout, " ".join(str(x) for x in info.shape)))

    images_count = len(input_image_sizes)
    binaries_count = len(input_info) - images_count

    image_files = []
    binary_files = []
    if path_to_input:
        # Sort so the file-to-input assignment does not depend on directory listing order.
        image_files = sorted(get_files_by_extensions(path_to_input, IMAGE_EXTENSIONS))
        binary_files = sorted(get_files_by_extensions(path_to_input, BINARY_EXTENSIONS))

    if not image_files and not binary_files:
        logger.warning("No input files were given: all inputs will be filled with random values!")
    else:
        # Every request consumes batch_size files for each input of a given kind.
        files_per_input = batch_size * len(requests)
        _warn_on_file_count_mismatch("binary", binaries_count * files_per_input, binary_files,
                                     BINARY_EXTENSIONS)
        _warn_on_file_count_mismatch("image", images_count * files_per_input, image_files,
                                     IMAGE_EXTENSIONS)

    keys = list(input_info.keys())
    requests_input_data = []
    for request_id in range(len(requests)):
        logger.info("Infer Request {} filling".format(request_id))
        input_data = {}
        for input_id, key in enumerate(keys):
            info = input_info[key]
            if is_image(info) and image_files:
                # input is image
                input_data[key] = fill_blob_with_image(image_files, request_id, batch_size, input_id,
                                                       len(keys), info.shape)
            elif binary_files:
                # input is binary (also used for image inputs when no image file was found)
                input_data[key] = fill_blob_with_binary(binary_files, info.shape)
            elif is_image_info(info) and len(input_image_sizes) == 1:
                # most likely input is image info: feed it the size of the only image input
                image_size = next(iter(input_image_sizes.values()))
                logger.info("Fill input '{}' with image size {}x{}".format(key, image_size[0], image_size[1]))
                input_data[key] = fill_blob_with_image_info(image_size, info.shape)
            else:
                # fill with random data
                logger.info("Fill input '{}' with random values ({} is expected)".format(
                    key, "image" if is_image(info) else "some binary data"))
                input_data[key] = fill_blob_with_random(info.precision, info.shape)
        requests_input_data.append(input_data)

    return requests_input_data
def get_files_by_extensions(path_to_input, extensions):
    """Return the files at *path_to_input* whose extension is listed in *extensions*.

    path_to_input: a single file, or a directory whose top level is scanned
        (subdirectories are not descended into; directories are never returned).
    extensions: extension strings without the leading dot, compared case-insensitively.

    A single file is filtered by extension too, so an image file is not also
    reported as a binary file. Returns a sorted list of paths.
    """
    wanted = {ext.upper() for ext in extensions}

    if os.path.isfile(path_to_input):
        candidates = [path_to_input]
    else:
        candidates = [entry for entry in glob(os.path.join(path_to_input, '*')) if os.path.isfile(entry)]

    # splitext yields '' for dot-less names, so a file called e.g. "BMP" does not match.
    return sorted(entry for entry in candidates
                  if os.path.splitext(entry)[1][1:].upper() in wanted)
def fill_blob_with_image(image_paths, request_id, batch_size, input_id, input_size, shape):
    """Fill an NCHW blob of *shape* with images read from *image_paths*.

    Files are reused round-robin. Batch slot ``b`` of request ``request_id`` for
    input ``input_id`` (out of ``input_size`` inputs) uses file number
    ``(request_id * batch_size * input_size + input_id + b * input_size) % len(image_paths)``.
    Images whose (H, W) differ from shape[2:] are resized, then converted from
    HWC to CHW. Returns a float64 ndarray of *shape*.

    Raises Exception when an image file cannot be decoded.
    """
    images = np.ndarray(shape)
    new_im_size = tuple(shape[2:])
    height, width = shape[2], shape[3]
    image_index = request_id * batch_size * input_size + input_id
    for b in range(batch_size):
        image_index %= len(image_paths)
        image_filename = image_paths[image_index]
        logger.info('Prepare image {}'.format(image_filename))
        image = cv2.imread(image_filename)
        if image is None:
            # cv2.imread signals unreadable/unsupported files by returning None.
            raise Exception("Failed to read image: {}".format(image_filename))
        if image.shape[:-1] != new_im_size:
            logger.warning("Image is resized from ({}) to ({})".format(image.shape[:-1], new_im_size))
            # cv2.resize expects dsize as (width, height), not (height, width).
            image = cv2.resize(image, (width, height))
        # HWC -> CHW
        images[b] = image.transpose((2, 0, 1))
        image_index += input_size
    return images
def fill_blob_with_image_info(image_size, shape):
    """Return a float64 array of *shape* describing the image size for every batch item.

    Along axis 1, entries 0 and 1 are set to image_size[0] and image_size[1]
    (the image input's height and width); every further entry is set to 1.
    """
    im_info = np.ndarray(shape)
    for i in range(shape[1]):
        # Assign the whole column at once instead of looping over the batch axis.
        im_info[:, i] = image_size[i] if i < 2 else 1
    return im_info
def fill_blob_with_random(precision, shape):
    """Return an array of *shape* filled with random values of the given *precision*.

    Floating-point precisions ("FP32", "FP16") get uniform values in [0, 1).
    Integer precisions ("I32", "U8", "I8", "U16", "I16") get uniform integers in
    [0, 255], clipped to the type's range (so "I8" uses [0, 127]). Casting
    ``rand()`` output to an integer type would truncate every value to 0.

    Raises Exception for any other precision.
    """
    float_types = {"FP32": np.float32, "FP16": np.float16}
    int_types = {"I32": np.int32, "U8": np.uint8, "I8": np.int8, "U16": np.uint16, "I16": np.int16}

    if precision in float_types:
        return np.random.rand(*shape).astype(float_types[precision])
    if precision in int_types:
        dtype = int_types[precision]
        limits = np.iinfo(dtype)
        low, high = max(int(limits.min), 0), min(int(limits.max), 255)
        # randint's upper bound is exclusive.
        return np.random.randint(low, high + 1, size=shape, dtype=dtype)
    raise Exception("Input precision is not supported: " + precision)