Publishing 2019 R2 content (#223)
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / annotation_converters / convert.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16 import warnings
17 import json
18 from pathlib import Path
19 from argparse import ArgumentParser
20 from functools import partial
21
22 import numpy as np
23
24 from ..utils import get_path
25 from ..representation import ReIdentificationClassificationAnnotation
26 from .format_converter import BaseFormatConverter
27
28
29 def build_argparser():
30     parser = ArgumentParser(
31         description="Converts annotation form a arbitrary format to accuracy-checker specific format", add_help=False
32     )
33     parser.add_argument(
34         "converter",
35         help="Specific converter to run",
36         choices=list(BaseFormatConverter.providers.keys())
37     )
38     parser.add_argument(
39         "-o", "--output_dir",
40         help="Directory to save converted annotation and meta info",
41         required=False,
42         type=partial(get_path, is_directory=True)
43     )
44     parser.add_argument("-m", "--meta_name", help="Meta info file name", required=False)
45     parser.add_argument("-a", "--annotation_name", help="Annotation file name", required=False)
46     parser.add_argument("-ss", "--subsample", help="Dataset subsample size", required=False)
47     parser.add_argument("--subsample_seed", help="Seed for generation dataset subsample", type=int, required=False)
48
49     return parser
50
51
52 def make_subset(annotation, size, seed=666):
53     def make_subset_pairwise(annotation, size):
54         def get_pairs(pairs_list):
55             pairs_set = set()
56             for identifier in pairs_list:
57                 next_annotation = next(
58                     pair_annotation for pair_annotation in annotation if pair_annotation.identifier == identifier
59                 )
60                 positive_pairs = get_pairs(next_annotation.positive_pairs)
61                 negative_pairs = get_pairs(next_annotation.negative_pairs)
62                 pairs_set.add(next_annotation)
63                 pairs_set.update(positive_pairs)
64                 pairs_set.update(negative_pairs)
65             return pairs_set
66
67         subsample_set = set()
68         while len(subsample_set) < size:
69             ann_ind = np.random.choice(len(annotation), 1)
70             annotation_for_subset = annotation[ann_ind[0]]
71             positive_pairs = annotation_for_subset.positive_pairs
72             negative_pairs = annotation_for_subset.negative_pairs
73             if len(positive_pairs) + len(negative_pairs) == 0:
74                 continue
75             updated_pairs = set()
76             updated_pairs.add(annotation_for_subset)
77             updated_pairs.update(get_pairs(positive_pairs))
78             updated_pairs.update(get_pairs(negative_pairs))
79             subsample_set.update(updated_pairs)
80         return list(subsample_set)
81
82     np.random.seed(seed)
83     dataset_size = len(annotation)
84     if dataset_size < size:
85         warnings.warn('Dataset size {} less than subset size {}'.format(dataset_size, size))
86         return annotation
87     if isinstance(annotation[-1], ReIdentificationClassificationAnnotation):
88         return make_subset_pairwise(annotation, size)
89
90
91     return list(np.random.choice(annotation, size=size, replace=False))
92
93
94 def main():
95     main_argparser = build_argparser()
96     args, _ = main_argparser.parse_known_args()
97     converter, converter_argparser, converter_args = get_converter_arguments(args)
98
99     main_argparser = ArgumentParser(parents=[main_argparser, converter_argparser])
100     args = main_argparser.parse_args()
101
102     converter = configure_converter(converter_args, args, converter)
103     out_dir = args.output_dir or Path.cwd()
104
105     result, meta = converter.convert()
106
107     subsample = args.subsample
108     if subsample:
109         if subsample.endswith('%'):
110             subsample_ratio = float(subsample[:-1]) / 100
111             subsample_size = int(len(result) * subsample_ratio)
112         else:
113             subsample_size = int(args.subsample)
114
115         result = make_subset(result, subsample_size)
116
117     converter_name = converter.get_name()
118     annotation_name = args.annotation_name or "{}.pickle".format(converter_name)
119     meta_name = args.meta_name or "{}.json".format(converter_name)
120
121     annotation_file = out_dir / annotation_name
122     meta_file = out_dir / meta_name
123
124     save_annotation(result, meta, annotation_file, meta_file)
125
126
127 def save_annotation(annotation, meta, annotation_file, meta_file):
128     if annotation_file:
129         with annotation_file.open('wb') as file:
130             for representation in annotation:
131                 representation.dump(file)
132     if meta_file and meta:
133         with meta_file.open('wt') as file:
134             json.dump(meta, file)
135
136
137 def configure_converter(converter_options, args, converter):
138     args_dict, converter_options_dict = vars(args), vars(converter_options)
139     converter_config = {
140         option_name: option_value for option_name, option_value in args_dict.items()
141         if option_name in converter_options_dict and option_value is not None
142     }
143     converter_config['converter'] = args.converter
144     converter.config = converter_config
145     converter.validate_config()
146     converter.configure()
147
148     return converter
149
150
151 def get_converter_arguments(arguments):
152     converter = BaseFormatConverter.provide(arguments.converter)
153     converter_argparser = converter.get_argparser()
154     converter_options, _ = converter_argparser.parse_known_args()
155     return converter, converter_argparser, converter_options