"""
Copyright (c) 2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

from ..adapters import Adapter
from ..config import ConfigValidator, StringField
from ..representation import (
    ContainerPrediction,
    RegressionPrediction,
    ClassificationPrediction,
    FacialLandmarksPrediction,
    MultiLabelRecognitionPrediction,
    GazeVectorPrediction,
)
class HeadPoseEstimatorAdapterConfig(ConfigValidator):
    """
    Config schema for the 'head_pose' adapter.

    Every field holds the name of a model output blob, one per head rotation angle.
    """
    # NOTE(review): the head_pose adapter validates with ERROR_ON_EXTRA_ARGUMENT, so any launcher key
    # not declared below is rejected; confirm whether an adapter 'type' key also needs a field here.
    angle_yaw = StringField()  # output blob name used for the yaw angle
    angle_pitch = StringField()  # output blob name used for the pitch angle
    angle_roll = StringField()  # output blob name used for the roll angle
class HeadPoseEstimatorAdapter(Adapter):
    """
    Converts head pose estimator outputs into ContainerPrediction objects.

    Each container holds one RegressionPrediction per angle, under the keys 'angle_yaw',
    'angle_pitch' and 'angle_roll'.
    """
    __provider__ = 'head_pose'

    def validate_config(self):
        """Validate launcher_config strictly: undeclared keys are reported as errors."""
        head_pose_estimator_adapter_config = HeadPoseEstimatorAdapterConfig(
            'HeadPoseEstimator_Config', on_extra_argument=HeadPoseEstimatorAdapterConfig.ERROR_ON_EXTRA_ARGUMENT)
        head_pose_estimator_adapter_config.validate(self.launcher_config)

    def configure(self):
        """
        Specifies parameters of config entry
        """
        # Names of the output blobs that carry each angle.
        self.angle_yaw = self.launcher_config['angle_yaw']
        self.angle_pitch = self.launcher_config['angle_pitch']
        self.angle_roll = self.launcher_config['angle_roll']

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            raw: output of model
            identifiers: list of input data identifiers; required in practice, because
                zip() raises TypeError when it is None
            frame_meta: list of meta information about each frame
        Returns:
            list of ContainerPrediction objects, one per identifier
        """
        result = []
        raw_output = self._extract_predictions(raw, frame_meta)
        for identifier, yaw, pitch, roll in zip(
                identifiers,
                raw_output[self.angle_yaw],
                raw_output[self.angle_pitch],
                raw_output[self.angle_roll]
        ):
            # Only the first element of each per-sample angle output is used.
            result.append(ContainerPrediction({
                'angle_yaw': RegressionPrediction(identifier, yaw[0]),
                'angle_pitch': RegressionPrediction(identifier, pitch[0]),
                'angle_roll': RegressionPrediction(identifier, roll[0]),
            }))

        return result
class VehicleAttributesRecognitionAdapterConfig(ConfigValidator):
    """
    Config schema for the 'vehicle_attributes' adapter.

    Both fields hold model output blob names.
    """
    # NOTE(review): validated with ERROR_ON_EXTRA_ARGUMENT by the vehicle_attributes adapter, so any
    # undeclared launcher key is rejected; confirm whether a 'type' key needs a field here.
    color_out = StringField()  # output blob with the color scores
    type_out = StringField()  # output blob with the vehicle type scores
class VehicleAttributesRecognitionAdapter(Adapter):
    """
    Converts vehicle attribute model outputs into ContainerPrediction objects.

    Each container holds 'color' and 'type' ClassificationPrediction entries built from the
    flattened per-sample scores.
    """
    __provider__ = 'vehicle_attributes'

    def validate_config(self):
        """Validate launcher_config strictly: undeclared keys are reported as errors."""
        attributes_recognition_adapter_config = VehicleAttributesRecognitionAdapterConfig(
            'VehicleAttributesRecognition_Config',
            on_extra_argument=VehicleAttributesRecognitionAdapterConfig.ERROR_ON_EXTRA_ARGUMENT)
        attributes_recognition_adapter_config.validate(self.launcher_config)

    def configure(self):
        """
        Specifies parameters of config entry
        """
        self.color_out = self.launcher_config['color_out']
        self.type_out = self.launcher_config['type_out']

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            raw: output of model
            identifiers: list of input data identifiers; required in practice, because
                zip() raises TypeError when it is None
            frame_meta: list of meta information about each frame
        Returns:
            list of ContainerPrediction objects, one per identifier
        """
        result = []
        raw_output = self._extract_predictions(raw, frame_meta)
        for identifier, colors, types in zip(identifiers, raw_output[self.color_out], raw_output[self.type_out]):
            result.append(ContainerPrediction({
                'color': ClassificationPrediction(identifier, colors.reshape(-1)),
                'type': ClassificationPrediction(identifier, types.reshape(-1)),
            }))

        return result
class AgeGenderAdapterConfig(ConfigValidator):
    """
    Config schema for the 'age_gender' adapter.

    Both fields hold model output blob names.
    """
    # NOTE(review): validated with ERROR_ON_EXTRA_ARGUMENT by the age_gender adapter, so any
    # undeclared launcher key is rejected; confirm whether a 'type' key needs a field here.
    age_out = StringField()  # output blob with the age prediction
    gender_out = StringField()  # output blob with the gender scores
class AgeGenderAdapter(Adapter):
    """
    Converts age/gender model outputs into ContainerPrediction objects.

    Each container holds:
        'gender': ClassificationPrediction over the flattened gender scores,
        'age_classification': one-hot ClassificationPrediction over four age groups,
        'age_error': RegressionPrediction with the predicted age.
    """
    __provider__ = 'age_gender'

    def configure(self):
        """
        Specifies parameters of config entry
        """
        self.age_out = self.launcher_config['age_out']
        self.gender_out = self.launcher_config['gender_out']

    def validate_config(self):
        """Validate launcher_config strictly: undeclared keys are reported as errors."""
        age_gender_adapter_config = AgeGenderAdapterConfig(
            'AgeGender_Config', on_extra_argument=AgeGenderAdapterConfig.ERROR_ON_EXTRA_ARGUMENT)
        age_gender_adapter_config.validate(self.launcher_config)

    @staticmethod
    def get_age_scores(age):
        """
        Build a one-hot vector over the age groups <=19, (19, 35], (35, 65] and >65.

        Args:
            age: scalar age value (process() multiplies the raw model output by 100 before calling this).
        Returns:
            numpy array of length 4 with 1 at the matching group and 0 elsewhere.
        """
        age_scores = np.zeros(4)
        for group, upper_bound in enumerate((19, 35, 65)):
            if age <= upper_bound:
                age_scores[group] = 1
                return age_scores
        # Above every explicit bound (or not comparable, e.g. NaN): last group.
        age_scores[3] = 1
        return age_scores

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            raw: output of model
            identifiers: list of input data identifiers; required in practice, because
                zip() raises TypeError when it is None
            frame_meta: list of meta information about each frame
        Returns:
            list of ContainerPrediction objects, one per identifier
        """
        result = []
        raw_output = self._extract_predictions(raw, frame_meta)
        for identifier, age, gender in zip(identifiers, raw_output[self.age_out], raw_output[self.gender_out]):
            gender = gender.reshape(-1)
            # The first element of the raw age output is scaled by 100 to get the age value.
            age = age.reshape(-1)[0] * 100
            result.append(ContainerPrediction({
                'gender': ClassificationPrediction(identifier, gender),
                'age_classification': ClassificationPrediction(identifier, self.get_age_scores(age)),
                'age_error': RegressionPrediction(identifier, age),
            }))

        return result
class LandmarksRegressionAdapter(Adapter):
    """
    Converts landmark regression outputs into FacialLandmarksPrediction objects.

    Each per-sample output is read as interleaved coordinates along its first axis: even positions
    are x values and odd positions are y values.
    """
    __provider__ = 'landmarks_regression'

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            raw: output of model
            identifiers: list of input data identifiers; required in practice, because
                zip() raises TypeError when it is None
            frame_meta: list of meta information about each frame
        Returns:
            list of FacialLandmarksPrediction objects, one per identifier
        """
        result = []
        raw_output = self._extract_predictions(raw, frame_meta)
        for identifier, values in zip(identifiers, raw_output[self.output_blob]):
            x_values, y_values = values[::2], values[1::2]
            result.append(FacialLandmarksPrediction(identifier, x_values.reshape(-1), y_values.reshape(-1)))

        return result
class PersonAttributesConfig(ConfigValidator):
    """
    Config schema for the 'person_attributes' adapter.

    The single field is optional; when it is absent the adapter falls back to its default output blob.
    """
    attributes_recognition_out = StringField(optional=True)  # output blob with the attribute scores
class PersonAttributesAdapter(Adapter):
    """
    Converts person attribute scores into MultiLabelRecognitionPrediction objects.

    Scores greater than 0.5 become 1. and scores less than or equal to 0.5 become 0.
    """
    __provider__ = 'person_attributes'

    def validate_config(self):
        """Validate launcher_config in IGNORE_ON_EXTRA_ARGUMENT mode (undeclared keys are tolerated)."""
        person_attributes_adapter_config = PersonAttributesConfig(
            'PersonAttributes_Config',
            on_extra_argument=PersonAttributesConfig.IGNORE_ON_EXTRA_ARGUMENT
        )
        person_attributes_adapter_config.validate(self.launcher_config)

    def configure(self):
        """
        Specifies parameters of config entry
        """
        # self.output_blob may not be set yet at this point; process() re-applies the fallback.
        self.attributes_recognition_out = self.launcher_config.get('attributes_recognition_out', self.output_blob)

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            raw: output of model
            identifiers: list of input data identifiers; required in practice, because
                zip() raises TypeError when it is None
            frame_meta: list of meta information about each frame
        Returns:
            list of MultiLabelRecognitionPrediction objects, one per identifier
        """
        result = []
        raw_output = self._extract_predictions(raw, frame_meta)
        output_name = self.attributes_recognition_out or self.output_blob
        for identifier, multi_label in zip(identifiers, raw_output[output_name]):
            # Binarize a copy so the launcher's raw output array is not overwritten in place.
            binary_labels = multi_label.copy()
            binary_labels[binary_labels > 0.5] = 1.
            binary_labels[binary_labels <= 0.5] = 0.
            result.append(MultiLabelRecognitionPrediction(identifier, binary_labels.reshape(-1)))

        return result
class GazeEstimationAdapter(Adapter):
    """
    Wraps each per-sample output of the adapter's output blob in a GazeVectorPrediction.
    """
    __provider__ = 'gaze_estimation'

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            raw: output of model
            identifiers: list of input data identifiers; required in practice, because
                zip() raises TypeError when it is None
            frame_meta: list of meta information about each frame
        Returns:
            list of GazeVectorPrediction objects, one per identifier
        """
        raw_output = self._extract_predictions(raw, frame_meta)
        return [
            GazeVectorPrediction(identifier, output)
            for identifier, output in zip(identifiers, raw_output[self.output_blob])
        ]