2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from .dependency import ClassProvider
22 from .logging import print_info
# Base progress reporter: a ClassProvider subclass declared with the provider
# type 'progress_reporter'. It remembers the dataset size and the moment a run
# started, and prints the total processing time when the run finishes.
# NOTE(review): this listing has gaps in its embedded numbering and no
# indentation, so some statements of this class are not visible here. `time`
# is used below but its import is not visible either - confirm it is imported.
25 class ProgressReporter(ClassProvider):
26 __provider_type__ = 'progress_reporter'
# dataset_size: optional number of objects that will be processed. When it is
# not None, reset() is invoked immediately so timing starts at construction.
28 def __init__(self, dataset_size=None):
30 self.dataset_size = None
31 self.start_time = None
33 if dataset_size is not None:
34 self.reset(dataset_size)
# Close the current run. objects_processed: whether anything was processed;
# the summary below is only meant for the processed case.
37 def finish(self, objects_processed=True):
# NOTE(review): the body of this guard is not visible in the listing -
# presumably an early return; confirm against the full file.
39 if not objects_processed:
# Wall-clock seconds elapsed since start_time was recorded by reset().
42 process_time = time.time() - self.start_time
43 print_info('{} objects processed in {:.3f} seconds'.format(self.dataset_size, process_time))
# Completion percentage: current / dataset_size * 100. Yields 0 when
# dataset_size is falsy (None or 0), which also avoids a division by zero.
# NOTE(review): the enclosing `def` line is not visible here, and no visible
# line of this class assigns self.current - confirm where it is initialised.
47 return (self.current / self.dataset_size) * 100 if self.dataset_size else 0
# Start a new run over dataset_size objects: store the size and restart the clock.
49 def reset(self, dataset_size):
# Closes the previous run without printing a summary.
# NOTE(review): a line preceding this call appears to be missing from the
# listing, so the call may be conditional - confirm.
51 self.finish(objects_processed=False)
54 self.dataset_size = dataset_size
55 self.start_time = time.time()
# Progress reporter declared with the provider name 'print': it writes a
# progress line through print_info once every `print_interval` batches.
# NOTE(review): this listing has gaps in its embedded numbering and no
# indentation, so some statements of this class are not visible here.
59 class PrintProgressReporter(ProgressReporter):
60 __provider__ = 'print'
# dataset_size: optional number of objects to process (forwarded to the base class).
# print_interval: number of batches between two progress lines (default 1000).
62 def __init__(self, dataset_size=None, print_interval=1000):
# The base constructor may call reset() - the override below - before
# print_interval is assigned; that is safe because reset() does not read it.
63 super().__init__(dataset_size)
64 self.print_interval = print_interval
# Start a new run: store the size, announce it, and start both the overall
# clock (start_time) and the per-interval clock (prev_time).
# NOTE(review): the visible lines do not call super().reset(), so the previous
# run is not finished here and self.current is not touched. update() does
# `self.current += batch_size`; confirm self.current is initialised somewhere.
66 def reset(self, dataset_size):
67 self.dataset_size = dataset_size
68 print_info('Total dataset size: {}'.format(dataset_size))
69 self.start_time = time.time()
70 self.prev_time = self.start_time
# Record one processed batch and, on every print_interval-th batch, print progress.
# batch_id: zero-based index of the batch (the code uses batch_id + 1 as a count).
# batch_size: number of objects in the batch.
72 def update(self, batch_id, batch_size):
73 self.current += batch_size
# NOTE(review): the body of this guard is not visible in the listing -
# presumably an early return for batches that should not be reported; confirm.
74 if (batch_id + 1) % self.print_interval != 0:
# NOTE(review): `now` is read here but its assignment is not visible -
# presumably `now = time.time()`; likewise an update of self.prev_time may be
# among the missing lines. Confirm against the full file.
78 batch_time = now - self.prev_time
# The processed count is computed as (batch_id + 1) * batch_size, which equals
# self.current only if every batch so far had this same batch_size.
81 print_info('{} / {} processed in {:.3f}s'.format((batch_id + 1) * batch_size, self.dataset_size, batch_time))
84 class TQDMReporter(ProgressReporter):
# Record one processed batch: add batch_size to the running counter and
# advance the progress bar by the same amount. _batch_id is accepted for
# signature compatibility with the other reporters' update() but is not used.
# NOTE(review): self.tqdm and self.current are not assigned in the visible
# lines of this class - presumably set up in reset(); confirm.
87 def update(self, _batch_id, batch_size):
88 self.current += batch_size
89 self.tqdm.update(batch_size)
# Finish the run by delegating to the parent class's finish() with the same flag.
# NOTE(review): the embedded numbering skips a line between the def and the
# super() call, so a statement (possibly progress-bar cleanup) may be missing
# from this listing - confirm against the full file.
91 def finish(self, objects_processed=True):
93 super().finish(objects_processed)
95 def reset(self, dataset_size):
96 super().reset(dataset_size)
98 total=self.dataset_size, unit='frames', leave=False,
99 bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'