Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / adapters / attributes_recognition.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import numpy as np
18
19 from ..adapters import Adapter
20 from ..config import ConfigValidator, StringField
21 from ..representation import (
22     ContainerPrediction,
23     RegressionPrediction,
24     ClassificationPrediction,
25     FacialLandmarksPrediction,
26     MultiLabelRecognitionPrediction,
27     GazeVectorPrediction
28 )
29
30
class HeadPoseEstimatorAdapterConfig(ConfigValidator):
    """Validation schema for the 'head_pose' adapter entry of the launcher config."""

    # adapter type name in the config entry
    type = StringField()
    # name of the model output blob that holds the yaw angle
    angle_yaw = StringField()
    # name of the model output blob that holds the pitch angle
    angle_pitch = StringField()
    # name of the model output blob that holds the roll angle
    angle_roll = StringField()
36
37
class HeadPoseEstimatorAdapter(Adapter):
    """
    Class for converting output of HeadPoseEstimator to HeadPosePrediction representation
    """
    __provider__ = 'head_pose'

    def validate_config(self):
        """Validate the launcher config entry, rejecting unknown arguments."""
        config_checker = HeadPoseEstimatorAdapterConfig(
            'HeadPoseEstimator_Config',
            on_extra_argument=HeadPoseEstimatorAdapterConfig.ERROR_ON_EXTRA_ARGUMENT
        )
        config_checker.validate(self.launcher_config)

    def configure(self):
        """
        Specifies parameters of config entry
        """
        # each angle parameter names the output blob it should be read from
        for angle_key in ('angle_yaw', 'angle_pitch', 'angle_roll'):
            setattr(self, angle_key, self.launcher_config[angle_key])

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            identifiers: list of input data identifiers
            raw: output of model
            frame_meta: list of meta information about each frame
        Returns:
                list of ContainerPrediction objects
        """
        raw_output = self._extract_predictions(raw, frame_meta)
        yaw_batch = raw_output[self.angle_yaw]
        pitch_batch = raw_output[self.angle_pitch]
        roll_batch = raw_output[self.angle_roll]

        results = []
        for identifier, yaw, pitch, roll in zip(identifiers, yaw_batch, pitch_batch, roll_batch):
            # each per-image blob carries a single scalar angle at index 0
            results.append(ContainerPrediction({
                'angle_yaw': RegressionPrediction(identifier, yaw[0]),
                'angle_pitch': RegressionPrediction(identifier, pitch[0]),
                'angle_roll': RegressionPrediction(identifier, roll[0])
            }))

        return results
80
81
class VehicleAttributesRecognitionAdapterConfig(ConfigValidator):
    """Validation schema for the 'vehicle_attributes' adapter entry of the launcher config."""

    # adapter type name in the config entry
    type = StringField()
    # name of the model output blob with per-color scores
    color_out = StringField()
    # name of the model output blob with per-type scores
    type_out = StringField()
87
class VehicleAttributesRecognitionAdapter(Adapter):
    """
    Class for converting vehicle attributes model output into a container with
    separate color and type classification representations.
    """
    __provider__ = 'vehicle_attributes'

    def validate_config(self):
        """Validate the launcher config entry, rejecting unknown arguments."""
        config_checker = VehicleAttributesRecognitionAdapterConfig(
            'VehicleAttributesRecognition_Config',
            on_extra_argument=VehicleAttributesRecognitionAdapterConfig.ERROR_ON_EXTRA_ARGUMENT)
        config_checker.validate(self.launcher_config)

    def configure(self):
        """
        Specifies parameters of config entry
        """
        self.color_out = self.launcher_config['color_out']
        self.type_out = self.launcher_config['type_out']

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            identifiers: list of input data identifiers
            raw: output of model
            frame_meta: list of meta information about each frame
        Returns:
            list of ContainerPrediction objects with 'color' and 'type' entries
        """
        raw_output = self._extract_predictions(raw, frame_meta)
        color_batch = raw_output[self.color_out]
        type_batch = raw_output[self.type_out]

        predictions = []
        for identifier, color_scores, type_scores in zip(identifiers, color_batch, type_batch):
            # flatten score blobs before wrapping them into classification results
            predictions.append(ContainerPrediction({
                'color': ClassificationPrediction(identifier, color_scores.reshape(-1)),
                'type': ClassificationPrediction(identifier, type_scores.reshape(-1))
            }))

        return predictions
111
112
class AgeGenderAdapterConfig(ConfigValidator):
    """Validation schema for the 'age_gender' adapter entry of the launcher config."""

    # adapter type name in the config entry
    type = StringField()
    # name of the model output blob with the age estimate
    age_out = StringField()
    # name of the model output blob with gender scores
    gender_out = StringField()
117
118
class AgeGenderAdapter(Adapter):
    """
    Class for converting age/gender model output into a container with gender
    classification, age-range classification and age regression representations.
    """
    __provider__ = 'age_gender'

    def configure(self):
        """
        Specifies parameters of config entry
        """
        self.age_out = self.launcher_config['age_out']
        self.gender_out = self.launcher_config['gender_out']

    def validate_config(self):
        """Validate the launcher config entry, rejecting unknown arguments."""
        config_checker = AgeGenderAdapterConfig(
            'AgeGender_Config', on_extra_argument=AgeGenderAdapterConfig.ERROR_ON_EXTRA_ARGUMENT)
        config_checker.validate(self.launcher_config)

    @staticmethod
    def get_age_scores(age):
        """
        Encode a scalar age as a one-hot vector over four age ranges:
        [0, 19), [19, 36), [36, 66) and [66, inf).
        """
        # the bucket index equals the number of range boundaries already reached
        bucket = sum(age >= boundary for boundary in (19, 36, 66))
        age_scores = np.zeros(4)
        age_scores[bucket] = 1
        return age_scores

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            identifiers: list of input data identifiers
            raw: output of model
            frame_meta: list of meta information about each frame
        Returns:
            list of ContainerPrediction objects with 'gender',
            'age_classification' and 'age_error' entries
        """
        raw_output = self._extract_predictions(raw, frame_meta)
        predictions = []
        for identifier, age_blob, gender_blob in zip(
                identifiers, raw_output[self.age_out], raw_output[self.gender_out]
        ):
            # scale the raw prediction by 100 (the model appears to emit age / 100)
            age = age_blob.reshape(-1)[0] * 100
            predictions.append(ContainerPrediction({
                'gender': ClassificationPrediction(identifier, gender_blob.reshape(-1)),
                'age_classification': ClassificationPrediction(identifier, self.get_age_scores(age)),
                'age_error': RegressionPrediction(identifier, age)
            }))

        return predictions
158
159
class LandmarksRegressionAdapter(Adapter):
    """
    Class for converting facial landmarks model output into
    FacialLandmarksPrediction representations.
    """
    __provider__ = 'landmarks_regression'

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            identifiers: list of input data identifiers
            raw: output of model
            frame_meta: list of meta information about each frame
        Returns:
            list of FacialLandmarksPrediction objects
        """
        raw_output = self._extract_predictions(raw, frame_meta)
        predictions = []
        for identifier, landmarks in zip(identifiers, raw_output[self.output_blob]):
            # coordinates are interleaved: even positions hold x, odd positions hold y
            x_coords = landmarks[::2].reshape(-1)
            y_coords = landmarks[1::2].reshape(-1)
            predictions.append(FacialLandmarksPrediction(identifier, x_coords, y_coords))

        return predictions
170
171
class PersonAttributesConfig(ConfigValidator):
    """Validation schema for the 'person_attributes' adapter entry of the launcher config."""

    # optional name of the output blob with attribute scores; when absent,
    # the adapter falls back to its default output blob
    attributes_recognition_out = StringField(optional=True)
174
175
class PersonAttributesAdapter(Adapter):
    """
    Class for converting person attributes model output into
    MultiLabelRecognitionPrediction representations.
    """
    __provider__ = 'person_attributes'

    def validate_config(self):
        """Validate the launcher config entry, silently ignoring unknown arguments."""
        config_checker = PersonAttributesConfig(
            'PersonAttributes_Config',
            PersonAttributesConfig.IGNORE_ON_EXTRA_ARGUMENT
        )
        config_checker.validate(self.launcher_config)

    def configure(self):
        # fall back to the default output blob when the config does not name one
        self.attributes_recognition_out = self.launcher_config.get('attributes_recognition_out', self.output_blob)

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            identifiers: list of input data identifiers
            raw: output of model
            frame_meta: list of meta information about each frame
        Returns:
            list of MultiLabelRecognitionPrediction objects with scores
            binarized at a 0.5 threshold
        """
        raw_output = self._extract_predictions(raw, frame_meta)
        output_name = self.attributes_recognition_out or self.output_blob
        predictions = []
        for identifier, scores in zip(identifiers, raw_output[output_name]):
            # binarize in place: scores above 0.5 become 1, the rest become 0
            scores[scores > 0.5] = 1.
            scores[scores <= 0.5] = 0.
            predictions.append(MultiLabelRecognitionPrediction(identifier, scores.reshape(-1)))

        return predictions
199
200
class GazeEstimationAdapter(Adapter):
    """
    Class for converting gaze estimation model output into
    GazeVectorPrediction representations.
    """
    __provider__ = 'gaze_estimation'

    def process(self, raw, identifiers=None, frame_meta=None):
        """
        Args:
            identifiers: list of input data identifiers
            raw: output of model
            frame_meta: list of meta information about each frame
        Returns:
            list of GazeVectorPrediction objects, one per identifier
        """
        raw_output = self._extract_predictions(raw, frame_meta)
        return [
            GazeVectorPrediction(identifier, gaze_vector)
            for identifier, gaze_vector in zip(identifiers, raw_output[self.output_blob])
        ]