b9c28c50394a5d9c932789cd488ed84b8bb792fc
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / launcher / caffe_launcher.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import re
18
19 import caffe
20
21 from ..config import PathField, StringField, NumberField, BoolField
22 from .launcher import Launcher, LauncherConfig
23
# Device specification: 'cpu' or 'gpu', the latter optionally followed by
# '_<index>' to pick a GPU (e.g. 'gpu_1'). The '$' after 'cpu' means that, under
# re.match, 'cpu' must be the whole string (so 'cpu_0' is rejected); the 'gpu'
# alternative has no end anchor, so re.match does not reject trailing text after
# it (e.g. 'gpux' still matches, with no identifier).
DEVICE_REGEX = r'(?P<device>cpu$|gpu)(_(?P<identifier>\d+))?'
25
26
class CaffeLauncherConfig(LauncherConfig):
    """
    Specifies configuration structure for Caffe launcher.

    CaffeLauncher validates its configuration entry against these fields
    when it is constructed.
    """

    # Model file (presumably the .prototxt network definition); passed as the
    # first argument of caffe.Net.
    model = PathField()
    # Trained weights file; passed as the second argument of caffe.Net.
    weights = PathField()
    # Target device string, checked against DEVICE_REGEX (e.g. 'cpu', 'gpu', 'gpu_1').
    device = StringField(regex=DEVICE_REGEX)
    # Optional integer batch size (>= 1); CaffeLauncher falls back to 1 when absent.
    batch = NumberField(floats=False, min_value=1, optional=True)
    # NOTE(review): declared here but not read by CaffeLauncher in this file
    # (output_blob always takes the first network output) — confirm whether it is
    # consumed elsewhere, e.g. by the base Launcher or adapters.
    output_name = StringField(optional=True)
    # Optional flag; when true, input blobs are reshaped to the full shape of the
    # incoming data before inference. CaffeLauncher treats a missing value as False.
    allow_reshape_input = BoolField(optional=True)
38
39
class CaffeLauncher(Launcher):
    """
    Class for infer model using Caffe framework.

    The Caffe mode/device is selected from the ``device`` config entry before the
    network is built; input blobs are reshaped on demand so that their batch
    dimension follows the data passed to ``predict``.
    """

    __provider__ = 'caffe'

    def __init__(self, config_entry: dict, *args, **kwargs):
        super().__init__(config_entry, *args, **kwargs)

        caffe_launcher_config = CaffeLauncherConfig('Caffe_Launcher')
        caffe_launcher_config.validate(self.config)

        # Configure Caffe's mode/device *before* constructing the network, so that
        # anything the net sets up on the device at construction time is created
        # on the requested GPU rather than on the default one.
        device = self.config['device']
        match = re.match(DEVICE_REGEX, device.lower())
        if match is None:
            # Should be prevented by validation above; fail with a clear message
            # instead of an AttributeError on ``None.group``.
            raise ValueError('Unsupported device for Caffe launcher: {}'.format(device))
        if match.group('device') == 'gpu':
            caffe.set_mode_gpu()
            caffe.set_device(int(match.group('identifier') or 0))
        elif match.group('device') == 'cpu':
            caffe.set_mode_cpu()

        self.model = str(self.config['model'])
        self.weights = str(self.config['weights'])

        self.network = caffe.Net(self.model, self.weights, caffe.TEST)
        self.allow_reshape_input = self.config.get('allow_reshape_input', False)

        self._batch = self.config.get('batch', 1)
        self.const_inputs = self.config.get('_list_const_inputs', [])

    @property
    def inputs(self):
        """
        Returns:
            dictionary mapping each non-constant input blob name to its shape
            WITHOUT the batch dimension, e.g. (channels, height, width) for
            image inputs.

        Side effect: every non-constant input blob is reshaped so that its batch
        dimension equals the configured batch size.
        """
        self._inputs_shapes = {}

        for input_blob in self.network.inputs:
            if input_blob in self.const_inputs:
                continue
            blob = self.network.blobs[input_blob]
            # Keep all non-batch dimensions as they are, so inputs of any rank
            # (not only 4D NCHW ones) are supported.
            sample_shape = tuple(blob.data.shape[1:])
            blob.reshape(self._batch, *sample_shape)
            self._inputs_shapes[input_blob] = sample_shape

        return self._inputs_shapes

    @property
    def batch(self):
        """Configured batch size (1 when not set in the config)."""
        return self._batch

    @property
    def output_blob(self):
        """
        Returns:
            name of the first output of the network.
        """
        # NOTE(review): the optional 'output_name' config field is not consulted
        # here — confirm whether it should select the output blob.
        return next(iter(self.network.outputs))

    def predict(self, inputs, metadata=None, *args, **kwargs):
        """
        Args:
            inputs: iterable of dictionaries, one per inference call, where keys are
                input blob names and values are data for them.
            metadata: metadata of input representations (may be None); each entry
                is passed to ``_provide_inputs_info_to_meta`` after every forward pass.
        Returns:
            list with the raw result of ``network.forward`` for each entry of inputs.
        """
        results = []
        for infer_input in inputs:
            for input_blob in self.network.inputs:
                if input_blob in self.const_inputs:
                    continue

                blob = self.network.blobs[input_blob]
                data = infer_input[input_blob]
                if self.allow_reshape_input:
                    blob.reshape(*data.shape)
                elif data.shape[0] != blob.data.shape[0]:
                    # Compare with the blob's *current* batch, not the configured
                    # one: a previous call may have left the blob with another
                    # batch size, and forward() rejects data whose batch differs.
                    blob.reshape(data.shape[0], *blob.data.shape[1:])

            results.append(self.network.forward(**infer_input))
            # NOTE(review): metadata is refreshed after every forward pass, as in
            # the original code; confirm whether once per predict() is enough.
            for image_meta in metadata or ():
                self._provide_inputs_info_to_meta(image_meta)

        return results

    def get_all_inputs(self):
        """
        Returns:
            dictionary mapping every input blob name (constant inputs included)
            to its current full shape, batch dimension included.
        """
        return {
            input_blob: self.network.blobs[input_blob].data.shape
            for input_blob in self.network.inputs
        }

    def release(self):
        """
        Releases launcher.
        """
        del self.network