Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / ie_bridges / python / src / openvino / inference_engine / ie_api.pyx
1 #distutils: language=c++
2 from cython.operator cimport dereference as deref
3 from .cimport ie_api_impl_defs as C
4 from .ie_api_impl_defs cimport Blob, TensorDesc, SizeVector, Precision
5 from libcpp.string cimport string
6 from libcpp.vector cimport vector
7 from libcpp.pair cimport pair
8 from libcpp.map cimport map
9 from libcpp.memory cimport unique_ptr
10 from libc.stdint cimport int64_t
11 import os
12 import numpy as np
13 from copy import deepcopy
14 import warnings
15 from collections import OrderedDict
16
# Declare std::move for unique_ptr[C.IEExecNetwork] so ownership of an
# executable network can be transferred between unique_ptr holders
# (used by IEPlugin.load in this file).
cdef extern from "<utility>" namespace "std" nogil:
    cdef unique_ptr[C.IEExecNetwork] move(unique_ptr[C.IEExecNetwork])
19
cdef string to_std_string(str py_string):
    """Encode a Python ``str`` and return the result as a C++ ``std::string``."""
    encoded = py_string.encode()
    return encoded
22
cdef to_py_string(const string & std_string):
    """Decode a C++ ``std::string`` into a Python ``str``."""
    raw = bytes(std_string)
    return raw.decode()
25
cdef dict_to_c_map(py_dict):
    """Convert a ``str -> str`` mapping into a C++ ``map[string, string]``.

    :param py_dict: mapping whose keys and values are all ``str`` instances
    :return: the encoded map (coerced back to a Python object on return,
             because this function has no C return type)
    :raises TypeError: if any key or value is not a ``str``
    """
    cdef map[string, string] c_map
    for key, value in py_dict.items():
        # isinstance (rather than an exact type comparison) also accepts
        # str subclasses, which encode the same way.
        if not (isinstance(key, str) and isinstance(value, str)):
            raise TypeError("Only string keys and values are allowed!")
        c_map[key.encode()] = value.encode()
    return c_map
33
# Precision names accepted by the precision setters and IENetwork.add_outputs.
# NOTE(review): BlobBuffer._get_blob_format handles 'U8' but it is not listed
# here, while 'U32' is listed here but has no buffer format there - confirm
# whether the two tables are meant to agree.
supported_precisions = ["FP32", "FP16", "Q78", "I32", "I16", "I8", "U32", "U16"]
# Layout names accepted by InputInfo.layout.
supported_layouts = ["NCHW", "NHWC", "OIHW", "C", "CHW", "HW", "NC", "CN", "BLOCKED", "NCDHW"]
# Device names accepted by IEPlugin (compared against the part before ':').
known_plugins = ['CPU', 'GPU', 'FPGA', 'MYRIAD', 'HETERO', 'HDDL']
37
def get_version():
    """Return the value of ``C.get_version()`` decoded to a Python ``str``."""
    raw_version = C.get_version()
    return raw_version.decode()
40
cdef class IENetLayer:
    """Python wrapper around a single network layer held in ``self.impl``.

    String attributes of the underlying object are decoded to ``str`` on read
    and encoded on write.
    """

    @property
    def name(self):
        """Layer name."""
        return self.impl.name.decode()

    @property
    def type(self):
        """Layer type."""
        return self.impl.type.decode()

    @property
    def precision(self):
        """Layer precision name."""
        return self.impl.precision.decode()

    @property
    def affinity(self):
        """Layer affinity string."""
        return self.impl.affinity.decode()

    @property
    def weights(self):
        """Dict mapping each weights blob name to a numpy array over that blob."""
        cdef map[string, Blob.Ptr] c_weights_map
        cdef BlobBuffer weights_buffer
        c_weights_map = self.impl.getWeights()
        weights_map = {}
        for weights in c_weights_map:
            # A fresh BlobBuffer per blob: each numpy array is built over
            # its own buffer object.
            weights_buffer = BlobBuffer()
            weights_buffer.reset(weights.second)
            weights_map[weights.first.decode()] = weights_buffer.to_numpy()
        return weights_map

    @property
    def params(self):
        """Layer parameters as a ``str -> str`` dict."""
        return {k.decode(): v.decode() for k, v in self.impl.params}

    @property
    def parents(self):
        """Names of the parent layers, as a list of ``str``."""
        cdef vector[string] c_parents = self.impl.parents
        return [parent.decode() for parent in c_parents]

    @property
    def children(self):
        """Names of the child layers, as a list of ``str``."""
        cdef vector[string] c_children = self.impl.children
        return [child.decode() for child in c_children]

    @property
    def shape(self):
        """Layer shape parsed from the whitespace-separated string in ``impl.shape``."""
        string_shape = self.impl.shape.decode()
        # split() without an argument tolerates an empty string and repeated
        # separators, where split(' ') would produce '' and make int() fail.
        return [int(dim) for dim in string_shape.split()]

    @property
    def layout(self):
        """Layer layout name."""
        return self.impl.layout.decode()

    @affinity.setter
    def affinity(self, target_affinity):
        self.impl.setAffinity(target_affinity.encode())

    @params.setter
    def params(self, params_map):
        # dict_to_c_map raises TypeError for non-string keys/values.
        self.impl.setParams(dict_to_c_map(params_map))

    @precision.setter
    def precision(self, precision: str):
        # The precision name is upper-cased before being passed on.
        self.impl.setPrecision(precision.upper().encode())
96
cdef class InputInfo:
    """Python wrapper around a network input description held in ``self.impl``."""

    @property
    def precision(self):
        """Input precision name."""
        return self.impl.precision.decode()

    @property
    def layout(self):
        """Input layout name."""
        return self.impl.layout.decode()

    @property
    def shape(self):
        """Input dimensions as stored in ``impl.dims``."""
        return self.impl.dims

    @precision.setter
    def precision(self, precision):
        """Set the precision.

        :raises AttributeError: if the upper-cased name is not in ``supported_precisions``
        """
        normalized = precision.upper()
        if normalized not in supported_precisions:
            raise AttributeError(
                "Unsupported precision {}! List of supported precisions: {}".format(precision, supported_precisions))
        # Pass the validated upper-case form, consistent with
        # IENetLayer.precision and IENetwork.add_outputs; previously a
        # lower-case name passed validation but was forwarded unchanged.
        self.impl.setPrecision(normalized.encode())

    @layout.setter
    def layout(self, layout):
        """Set the layout.

        :raises AttributeError: if the upper-cased name is not in ``supported_layouts``
        """
        normalized = layout.upper()
        if normalized not in supported_layouts:
            raise AttributeError(
                "Unsupported layout {}! List of supported layouts: {}".format(layout, supported_layouts))
        # Forward the same upper-case form that was validated.
        self.impl.setLayout(normalized.encode())
120
cdef class OutputInfo:
    """Python wrapper around a network output description held in ``self.impl``."""

    @property
    def precision(self):
        """Output precision name."""
        return self.impl.precision.decode()

    @property
    def layout(self):
        """Output layout name."""
        return self.impl.layout.decode()

    @property
    def shape(self):
        """Output dimensions as stored in ``impl.dims``."""
        return self.impl.dims

    @precision.setter
    def precision(self, precision):
        """Set the precision.

        :raises AttributeError: if the upper-cased name is not in ``supported_precisions``
        """
        normalized = precision.upper()
        if normalized not in supported_precisions:
            raise AttributeError(
                "Unsupported precision {}! List of supported precisions: {}".format(precision, supported_precisions))
        # Pass the validated upper-case form, consistent with
        # IENetLayer.precision and IENetwork.add_outputs.
        self.impl.setPrecision(normalized.encode())
137
cdef class ExecutableNetwork:
    """A network loaded onto a plugin, exposing its infer requests.

    ``inputs`` and ``outputs`` hold the network's input and output names and
    are handed to every ``InferRequest`` wrapper built by ``requests``.
    """

    def __init__(self):
        self._requests = []
        self.inputs = []
        self.outputs = []

    def infer(self, inputs=None):
        """Run a synchronous inference on request 0.

        :param inputs: optional dict ``input name -> data`` copied into the
                       input blobs before inference
        :return: dict ``output name -> numpy array``
        """
        current_request = self.requests[0]
        current_request.infer(inputs)
        # InferRequest.outputs already returns a deep copy, so copying the
        # result a second time here would only duplicate the work.
        return current_request.outputs

    def start_async(self, request_id, inputs=None):
        """Start an asynchronous inference on the request with index ``request_id``.

        :param request_id: index into ``requests``
        :param inputs: optional dict ``input name -> data``
        :return: the ``InferRequest`` wrapper that was started
        :raises ValueError: if ``request_id`` is not a valid request index
        """
        # Build the wrapper list once: every access to ``self.requests``
        # creates a new set of InferRequest objects.
        requests = self.requests
        if request_id not in range(len(requests)):
            raise ValueError("Incorrect request_id specified!")
        current_request = requests[request_id]
        current_request.async_infer(inputs)
        return current_request

    @property
    def requests(self):
        """List of newly created ``InferRequest`` wrappers, one per underlying infer request."""
        cdef InferRequest infer_request
        requests = []
        for i in range(deref(self.impl).infer_requests.size()):
            infer_request = InferRequest()
            # The wrapper stores a pointer to the request held by self.impl.
            infer_request.impl = &(deref(self.impl).infer_requests[i])
            infer_request._inputs_list = self.inputs
            infer_request._outputs_list = self.outputs
            requests.append(infer_request)
        return requests
166
cdef class InferRequest:
    """Wrapper around a single infer request referenced through ``self.impl``.

    ``_inputs_list`` and ``_outputs_list`` hold the input and output blob
    names; they are assigned by ``ExecutableNetwork.requests``.
    """

    def __init__(self):
        self._inputs_list = []
        self._outputs_list = []

    cpdef BlobBuffer _get_blob_buffer(self, const string & blob_name):
        """Return a ``BlobBuffer`` over the blob named ``blob_name``."""
        cdef BlobBuffer buffer = BlobBuffer()
        cdef Blob.Ptr blob_ptr
        deref(self.impl).getBlobPtr(blob_name, blob_ptr)
        buffer.reset(blob_ptr)
        return buffer

    cpdef infer(self, inputs=None):
        """Optionally fill the input blobs from ``inputs``, then run ``infer()``."""
        if inputs is not None:
            self._fill_inputs(inputs)

        deref(self.impl).infer()

    cpdef async_infer(self, inputs=None):
        """Optionally fill the input blobs from ``inputs``, then call ``infer_async()``."""
        if inputs is not None:
            self._fill_inputs(inputs)

        deref(self.impl).infer_async()

    cpdef wait(self, timeout=None):
        """Call ``wait`` on the underlying request and return its result.

        :param timeout: value passed through as ``int64_t``; ``None`` is
                        replaced by ``-1``
        """
        if timeout is None:
            timeout = -1
        return deref(self.impl).wait(<int64_t> timeout)

    cpdef get_perf_counts(self):
        """Return per-layer profiling data.

        :return: dict ``layer name -> {"status", "exec_type", "layer_type",
                 "real_time", "cpu_time"}``
        """
        cdef map[string, C.ProfileInfo] c_profile = deref(self.impl).getPerformanceCounts()
        profile = {}
        for entry in c_profile:
            info = entry.second
            # TODO: add execution index. Check if unsigned int is properly converted to int in python.
            profile[entry.first.decode()] = {"status": info.status.decode(), "exec_type": info.exec_type.decode(),
                                             "layer_type": info.layer_type.decode(), "real_time": info.real_time,
                                             "cpu_time": info.cpu_time}
        return profile

    @property
    def inputs(self):
        """Dict ``input name -> numpy array`` built over the input blobs (no copy is made here)."""
        inputs = {}
        for input in self._inputs_list:
            inputs[input] = self._get_blob_buffer(input.encode()).to_numpy()
        return inputs

    @property
    def outputs(self):
        """Dict ``output name -> numpy array``, deep-copied from the output blobs."""
        outputs = {}
        for output in self._outputs_list:
            outputs[output] = self._get_blob_buffer(output.encode()).to_numpy()
        return deepcopy(outputs)

    @property
    def latency(self):
        """Value of ``exec_time`` on the underlying request."""
        return self.impl.exec_time

    def set_batch(self, size):
        """Set the batch size for this request.

        :raises ValueError: if ``size`` is not positive
        """
        if size <= 0:
            raise ValueError("Batch size should be positive integer number but {} specified".format(size))
        deref(self.impl).setBatch(size)

    def _fill_inputs(self, inputs):
        """Copy each value of ``inputs`` into the input blob of the same name.

        :param inputs: dict ``input name -> data``
        :raises AssertionError: if a key is not one of the network inputs
        """
        for name, data in inputs.items():
            # Raised explicitly (same exception type and message as the
            # former assert) so the check cannot be compiled out.
            if name not in self._inputs_list:
                raise AssertionError("No input with name {} found in network".format(name))
            # Access only the blob being written; going through
            # ``self.inputs`` would rebuild the arrays of every input on
            # each iteration.
            self._get_blob_buffer(name.encode()).to_numpy()[:] = data
234
235
class LayerStats:
    """Read-only holder for the ``min`` and ``max`` statistics of one layer."""

    def __init__(self, min: tuple = (), max: tuple = ()):
        # Stored as given; exposed only through the read-only properties below.
        self._min, self._max = min, max

    @property
    def min(self):
        """The ``min`` value passed to the constructor."""
        return self._min

    @property
    def max(self):
        """The ``max`` value passed to the constructor."""
        return self._max
247
248
cdef class LayersStatsMap(dict):
    """Dict ``layer name -> LayerStats`` that pushes its contents to the network on ``update``."""

    def update(self, other=None, **kwargs):
        """Update the dict like ``dict.update`` and send all statistics to ``self.net_impl``.

        :param other: optional mapping or iterable of pairs, as for ``dict.update``
        :param kwargs: additional ``name=LayerStats`` entries
        """
        cdef map[string, map[string, vector[float]]] c_stats_map
        cdef map[string, vector[float]] c_node_stats
        # dict.update(None) raises TypeError, so the positional argument is
        # forwarded only when the caller actually supplied one.
        if other is None:
            super(LayersStatsMap, self).update(**kwargs)
        else:
            super(LayersStatsMap, self).update(other, **kwargs)
        for layer_name, layer_stats in self.items():
            # Both keys are overwritten on every iteration, so reusing
            # c_node_stats carries nothing over between layers.
            c_node_stats[b"min"] = layer_stats.min
            c_node_stats[b"max"] = layer_stats.max
            c_stats_map[layer_name.encode()] = c_node_stats
        self.net_impl.setStats(c_stats_map)
259
cdef class IENetwork:
    """Python wrapper around a network object held in ``self.impl``."""

    def __cinit__(self, model: str="", weights: str=""):
        """Create the network.

        When both ``model`` and ``weights`` are non-empty they must be paths
        to existing files and are passed to ``C.IENetwork``; otherwise a
        default-constructed ``C.IENetwork`` is used.

        :raises Exception: if either path is not an existing file
        """
        cdef string model_
        cdef string weights_
        if model and weights:
            if not os.path.isfile(model):
                raise Exception("Path to the model {} doesn't exists or it's a directory".format(model))
            if not os.path.isfile(weights):
                raise Exception("Path to the weights {} doesn't exists or it's a directory".format(weights))
            model_ = model.encode()
            weights_ = weights.encode()
            self.impl = C.IENetwork(model_, weights_)
        else:
            self.impl = C.IENetwork()

    @property
    def name(self):
        """Network name."""
        name = bytes(self.impl.name)
        return name.decode()

    @property
    def inputs(self):
        """Dict ``input name -> InputInfo`` built from ``impl.getInputs()``."""
        cdef map[string, C.InputInfo] c_inputs = self.impl.getInputs()
        cdef InputInfo in_info
        inputs = {}
        for input in c_inputs:
            in_info = InputInfo()
            in_info.impl = input.second
            inputs[input.first.decode()] = in_info
        return inputs

    @property
    def outputs(self):
        """Dict ``output name -> OutputInfo`` built from ``impl.getOutputs()``."""
        cdef map[string, C.OutputInfo] c_outputs = self.impl.getOutputs()
        cdef OutputInfo out_info
        outputs = {}
        for out in c_outputs:
            out_info = OutputInfo()
            out_info.impl = out.second
            outputs[out.first.decode()] = out_info
        return outputs

    @property
    def batch_size(self):
        """Batch size stored on ``impl``."""
        return self.impl.batch_size

    @batch_size.setter
    def batch_size(self, batch: int):
        """Set the batch size.

        :raises AttributeError: if ``batch`` is not positive
        """
        if batch <= 0:
            raise AttributeError("Invalid batch size {}! Batch size should be positive integer value".format(batch))
        self.impl.setBatch(batch)
        self.impl.batch_size = batch

    @property
    def layers(self):
        """``OrderedDict`` ``layer name -> IENetLayer`` in the order returned by ``impl.getLayers()``."""
        cdef vector[pair[string, C.IENetLayer]] c_layers = self.impl.getLayers()
        # Declared without an initial instance: a wrapper is created per
        # layer inside the loop.
        cdef IENetLayer net_l
        layers = OrderedDict()
        for l in c_layers:
            net_l = IENetLayer()
            net_l.impl = l.second
            layers[l.first.decode()] = net_l
        return layers

    @property
    def stats(self):
        """``LayersStatsMap`` ``layer name -> LayerStats`` built from ``impl.getStats()``."""
        cdef map[string, map[string, vector[float]]] c_stats_map = self.impl.getStats()
        py_stats_map = LayersStatsMap()
        py_stats_map.net_impl = self.impl
        for it in c_stats_map:
            py_stats_map[it.first.decode()] = LayerStats(min=tuple(it.second["min".encode()]),
                                                         max=tuple(it.second["max".encode()]))
        return py_stats_map

    @classmethod
    def from_ir(cls, model: str, weights: str):
        """Deprecated: create an ``IENetwork`` from model and weights files.

        Emits a ``DeprecationWarning``; use the class constructor instead.

        :raises Exception: if either path is not an existing file
        """
        # Installs an "always" filter so the deprecation message is shown on
        # every call. This modifies the process-wide warnings filters.
        warnings.filterwarnings("always",category=DeprecationWarning)
        warnings.warn("from_ir() method of IENetwork is deprecated. "
                      "Please use IENetwork class constructor to create valid IENetwork instance",
                      DeprecationWarning)
        if not os.path.isfile(model):
            raise Exception("Path to the model {} doesn't exists or it's a directory".format(model))
        if not os.path.isfile(weights):
            raise Exception("Path to the weights {} doesn't exists or it's a directory".format(weights))
        cdef IENetwork net = IENetwork(model, weights)
        return net

    # TODO: Use enum with precision type instead of string parameter when python2 support will not be required.
    def add_outputs(self, outputs, precision="FP32"):
        """Add output layers to the network.

        :param outputs: a layer name or a list of layer names
        :param precision: precision name, matched case-insensitively against
                          ``supported_precisions``
        :raises AttributeError: if the precision is not supported
        """
        if precision.upper() not in supported_precisions:
            raise AttributeError(
                "Unsupported precision {}! List of supported precisions: {}".format(precision, supported_precisions))
        if not isinstance(outputs, list):
            outputs = [outputs]
        cdef vector[string] _outputs
        for l in outputs:
            _outputs.push_back(l.encode())
        self.impl.addOutputs(_outputs, precision.upper().encode())

    def serialize(self, path_to_xml, path_to_bin):
        """Call ``impl.serialize`` with the two given paths."""
        self.impl.serialize(path_to_xml.encode(), path_to_bin.encode())

    def reshape(self, input_shapes: dict):
        """Reshape the network.

        :param input_shapes: dict ``input name -> iterable of dimensions``
        :raises AttributeError: if a key is not one of the network inputs
        """
        cdef map[string, vector[size_t]] c_input_shapes
        cdef vector[size_t] c_shape
        net_inputs = self.inputs
        for input, shape in input_shapes.items():
            # Start from an empty vector for every input.
            c_shape.clear()
            if input not in net_inputs:
                raise AttributeError("Specified {} layer not in network inputs {}! ".format(input, net_inputs))
            for v in shape:
                c_shape.push_back(v)
            c_input_shapes[input.encode()] = c_shape
        self.impl.reshape(c_input_shapes)
372
cdef class IEPlugin:
    """Python wrapper around a device plugin held in ``self.impl``."""

    def __cinit__(self, device: str, plugin_dirs=None):
        """Create the plugin for ``device``.

        :param device: device name; the part before the first ``':'`` must be
                       one of ``known_plugins``
        :param plugin_dirs: ``None``, a single directory string, or an
                            iterable of directory strings; the directory of
                            this module is always added at the end
        :raises ValueError: if the device is not a known plugin
        """
        cdef string device_
        cdef vector[string] dirs_

        plugin_base = device.split(':')[0]
        if plugin_base not in known_plugins:
            raise ValueError("Unknown plugin: {}, expected one of: {}"
                             .format(plugin_base, ",".join(known_plugins)))
        # Work on a private list so the caller's ``plugin_dirs`` object is
        # never modified by the append below.
        if plugin_dirs is None:
            search_dirs = [""]
        elif isinstance(plugin_dirs, str):
            search_dirs = [plugin_dirs]
        else:
            search_dirs = list(plugin_dirs)

        # add package directory to the search directories
        lib_location = os.path.dirname(os.path.realpath(__file__))
        search_dirs.append(lib_location)

        device_ = <string> device.encode()
        for d in search_dirs:
            dirs_.push_back(<string> d.encode())

        self.impl = C.IEPlugin(device_, dirs_)

    cpdef ExecutableNetwork load(self, IENetwork network, int num_requests=1, config=None):
        """Load ``network`` onto the plugin.

        :param network: the network to load
        :param num_requests: number of infer requests to create; must be positive
        :param config: optional dict of ``str -> str`` configuration entries
        :return: the resulting ``ExecutableNetwork``
        :raises ValueError: if ``num_requests`` is not positive
        """
        cdef ExecutableNetwork exec_net
        cdef map[string, string] c_config

        if num_requests <= 0:
            raise ValueError(
                "Incorrect number of requests specified: {}. Expected positive integer number.".format(num_requests))
        exec_net = ExecutableNetwork()

        if config:
            for k, v in config.items():
                c_config[to_std_string(k)] = to_std_string(v)
        exec_net.plugin_impl = self.impl
        exec_net.impl = move(self.impl.load(network.impl, num_requests, c_config))
        # Store plain lists for both, matching ExecutableNetwork.__init__
        # (inputs used to be stored as a dict keys view).
        exec_net.inputs = list(network.inputs.keys())
        exec_net.outputs = list(network.outputs.keys())
        return exec_net

    cpdef void set_initial_affinity(self, IENetwork net) except *:
        """Set the initial layer affinity of ``net``.

        :raises RuntimeError: if the device name does not contain "HETERO"
        """
        if self.device.find("HETERO") == -1:
            raise RuntimeError("set_initial_affinity method applicable only for HETERO device")
        self.impl.setInitialAffinity(net.impl)

    cpdef set get_supported_layers(self, IENetwork net):
        """Return the set of names produced by ``impl.queryNetwork(net.impl)``."""
        return {l.decode() for l in self.impl.queryNetwork(net.impl)}

    @property
    def device(self):
        """Device name of the plugin."""
        device_name = bytes(self.impl.device_name)
        return to_py_string(device_name)

    @property
    def version(self):
        """Plugin version string."""
        version = bytes(self.impl.version)
        return version.decode()

    cpdef void add_cpu_extension(self, str extension_path) except *:
        """Add a CPU extension library.

        :raises RuntimeError: if the device name does not contain "CPU"
        """
        if self.device.find("CPU") == -1:
            raise RuntimeError("add_cpu_extension method applicable only for CPU or HETERO devices")
        cdef string extension_str = extension_path.encode()
        self.impl.addCpuExtension(extension_str)

    # NOTE(review): unlike the other ``cpdef void`` methods of this class,
    # this one is declared without ``except *``, so an exception raised in
    # its body presumably cannot propagate to the caller. Adding the clause
    # also requires changing the matching declaration outside this file -
    # confirm and fix both together.
    cpdef void set_config(self, config):
        """Pass the ``str -> str`` entries of ``config`` to ``impl.setConfig``."""
        cdef map[string, string] c_config
        for k, v in config.items():
            c_config[to_std_string(k)] = to_std_string(v)
        self.impl.setConfig(c_config)
440
441
cdef class BlobBuffer:
    """Copy-less accessor for Inference Engine Blob"""

    cdef reset(self, Blob.Ptr & ptr):
        """Attach this buffer to ``ptr`` and compute its shape, strides and format."""
        self.ptr = ptr
        cdef TensorDesc desc = deref(ptr).getTensorDesc()
        cdef SizeVector shape = desc.getDims()
        cdef Py_ssize_t itemsize = deref(ptr).element_size()
        self.strides.resize(shape.size())
        self.shape.resize(shape.size())

        total_stride = itemsize
        # dims are in row major (C - style),
        # hence strides are computed starting from the last dimension
        for i in reversed(range(shape.size())):
            self.strides[i] = total_stride
            self.shape[i] = shape[i]
            total_stride *= shape[i]

        # After the loop this is itemsize multiplied by every dimension,
        # i.e. the byte length reported by __getbuffer__.
        self.total_stride = total_stride
        self.format = self._get_blob_format(desc)
        self.item_size = itemsize

    def __getbuffer__(self, Py_buffer *buffer, int flags):
        """Fill ``buffer`` with a writable view described by the values computed in ``reset``."""
        buffer.buf = C.get_buffer[char](deref(self.ptr))
        buffer.format = self.format
        buffer.internal = NULL
        buffer.itemsize = self.item_size
        buffer.len = self.total_stride
        buffer.ndim = self.shape.size()
        buffer.obj = self
        buffer.readonly = 0
        buffer.shape = self.shape.data()
        buffer.strides = self.strides.data()
        buffer.suboffsets = NULL

    cdef char*_get_blob_format(self, const TensorDesc & desc):
        """Return the buffer format character for the precision of ``desc``.

        Every value returned is a bytes literal, which Cython emits as a
        static C string, so the pointer stays valid after this function
        returns. (Returning the buffer of a temporary ``bytes`` object
        created by ``.encode()`` would leave the pointer dangling once that
        object is released.)

        :raises ValueError: if the precision name has no known format
        """
        # NOTE(review): this cdef function returns ``char*`` and has no
        # ``except`` clause, so the ValueError below may not reach the caller
        # of ``reset`` - confirm against the Cython version in use.
        cdef Precision precision = desc.getPrecision()
        name = bytes(precision.name()).decode()
        # todo: half floats
        if name == 'FP32':
            return b'f'  # float
        elif name in ('FP16', 'Q78', 'I16'):
            return b'h'  # signed short
        elif name == 'U8':
            return b'B'  # unsigned char
        elif name == 'I8':
            return b'b'  # signed char
        elif name == 'U16':
            return b'H'  # unsigned short
        elif name == 'I32':
            return b'i'  # signed int

        raise ValueError("Unknown Blob precision: {}".format(name))

    def to_numpy(self):
        """Return ``np.asarray(self)``: a numpy array over this buffer."""
        return np.asarray(self)