2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from ..config import PathField, StringField, NumberField, BoolField
22 from .launcher import Launcher, LauncherConfig
24 DEVICE_REGEX = r'(?P<device>cpu$|gpu)(_(?P<identifier>\d+))?'
27 class CaffeLauncherConfig(LauncherConfig):
29 Specifies configuration structure for Caffe launcher.
34 device = StringField(regex=DEVICE_REGEX)
35 batch = NumberField(floats=False, min_value=1, optional=True)
36 output_name = StringField(optional=True)
37 allow_reshape_input = BoolField(optional=True)
40 class CaffeLauncher(Launcher):
42 Class for infer model using Caffe framework.
45 __provider__ = 'caffe'
47 def __init__(self, config_entry: dict, *args, **kwargs):
48 super().__init__(config_entry, *args, **kwargs)
50 caffe_launcher_config = CaffeLauncherConfig('Caffe_Launcher')
51 caffe_launcher_config.validate(self.config)
53 self.model = str(self.config['model'])
54 self.weights = str(self.config['weights'])
56 self.network = caffe.Net(self.model, self.weights, caffe.TEST)
57 self.allow_reshape_input = self.config.get('allow_reshape_input', False)
59 match = re.match(DEVICE_REGEX, self.config['device'].lower())
60 if match.group('device') == 'gpu':
62 identifier = match.group('identifier') or 0
63 caffe.set_device(int(identifier))
64 elif match.group('device') == 'cpu':
67 self._batch = self.config.get('batch', 1)
68 self.const_inputs = self.config.get('_list_const_inputs', [])
74 inputs in NCHW format.
76 self._inputs_shapes = {}
78 for input_blob in self.network.inputs:
79 if input_blob in self.const_inputs:
81 channels, height, width = self.network.blobs[input_blob].data.shape[1:]
82 self.network.blobs[input_blob].reshape(self._batch, channels, height, width)
83 self._inputs_shapes[input_blob] = channels, height, width
85 return self._inputs_shapes
92 def output_blob(self):
93 return next(iter(self.network.outputs))
95 def predict(self, inputs, metadata, *args, **kwargs):
98 inputs: dictionary where keys are input layers names and values are data for them.
99 metadata: metadata of input representations
101 raw data from network.
104 for infer_input in inputs:
105 for input_blob in self.network.inputs:
106 if input_blob in self.const_inputs:
109 data = infer_input[input_blob]
110 if self.allow_reshape_input:
111 self.network.blobs[input_blob].reshape(*data.shape)
113 if data.shape[0] != self._batch:
114 self.network.blobs[input_blob].reshape(
115 data.shape[0], *self.network.blobs[input_blob].data.shape[1:]
118 results.append(self.network.forward(**infer_input))
119 for image_meta in metadata:
120 self._provide_inputs_info_to_meta(image_meta)
124 def get_all_inputs(self):
126 for input_blob in self.network.inputs:
127 inputs_map[input_blob] = self.network.blobs[input_blob].data.shape