2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import time

from .dependency import ClassProvider
from .logging import print_info
class ProgressReporter(ClassProvider):
    """Base progress reporter registered under the 'progress_reporter' provider type.

    Keeps the size of the dataset being processed and the wall-clock time at which
    the current run started, and prints a summary line when the run finishes.
    """

    __provider_type__ = 'progress_reporter'

    def __init__(self, dataset_size=None):
        """Create a reporter, optionally starting a run right away.

        Args:
            dataset_size: number of objects to process; when not None, ``reset`` is
                called with it so timing starts immediately.
        """
        # No run is in progress yet, so the first ``reset`` must not try to finish one
        # (there is no start time to measure from, and subclasses may have no state yet).
        self.finished = True
        self.dataset_size = None
        self.start_time = None

        if dataset_size is not None:
            self.reset(dataset_size)

    def finish(self, objects_processed=True):
        """Mark the current run as finished.

        Args:
            objects_processed: when True, print how many objects were processed and
                how long the run took; when False, only mark the run as finished
                (this is what ``reset`` uses to close an abandoned run).
        """
        self.finished = True
        # Nothing to report if told so, or if no run was ever started.
        if not objects_processed or self.start_time is None:
            return

        process_time = time.time() - self.start_time
        print_info('{} objects processed in {:.3f} seconds'.format(self.dataset_size, process_time))

    def reset(self, dataset_size):
        """Start a new run over ``dataset_size`` objects.

        A run that is still in progress is closed first, without printing a summary.
        """
        if not self.finished:
            self.finish(objects_processed=False)

        self.dataset_size = dataset_size
        self.start_time = time.time()
        self.finished = False
class PrintProgressReporter(ProgressReporter):
    """Progress reporter that prints a status line every ``print_interval`` batches."""

    __provider__ = 'print'

    def __init__(self, dataset_size=None, print_interval=1000):
        """Create the reporter.

        Args:
            dataset_size: number of objects to process; when not None the base
                constructor starts a run immediately via ``reset``.
            print_interval: a status line is printed after every this-many batches.
        """
        # Assigned before the base constructor, which may call ``reset`` on this object.
        self.print_interval = print_interval
        super().__init__(dataset_size)

    def reset(self, dataset_size):
        """Start a new run: announce the dataset size and restart both timers."""
        self.dataset_size = dataset_size
        print_info('Total dataset size: {}'.format(dataset_size))
        self.start_time = time.time()
        # Reference point for the elapsed time shown by the next progress line.
        self.prev_time = self.start_time

    def update(self, batch_id, batch_size):
        """Report progress after every ``print_interval``-th batch.

        Args:
            batch_id: index of the batch just processed; ``batch_id + 1`` is treated as
                the number of batches done so far (i.e. the index is zero-based).
            batch_size: objects per batch; the processed count is estimated as
                ``(batch_id + 1) * batch_size``, which presumes uniform batch sizes.
        """
        if (batch_id + 1) % self.print_interval != 0:
            return

        now = time.time()
        batch_time = now - self.prev_time
        # Advance the reference point so each line shows time since the previous line.
        self.prev_time = now

        print_info('{} / {} processed in {:.3f}s'.format((batch_id + 1) * batch_size, self.dataset_size, batch_time))
77 class TQDMReporter(ProgressReporter):
80 def update(self, _batch_id, batch_size):
81 self.tqdm.update(batch_size)
83 def finish(self, objects_processed=True):
85 super().finish(objects_processed)
87 def reset(self, dataset_size):
88 super().reset(dataset_size)
90 total=self.dataset_size, unit='frames', leave=False,
91 bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'