Python 2/3 compatible download_model_binary.py
[platform/upstream/caffeonacl.git] / scripts / download_model_binary.py
1 #!/usr/bin/env python
2 import os
3 import sys
4 import time
5 import yaml
6 import hashlib
7 import argparse
8
9 from six.moves import urllib
10
11 required_keys = ['caffemodel', 'caffemodel_url', 'sha1']
12
13
14 def reporthook(count, block_size, total_size):
15     """
16     From http://blog.moleculea.com/2012/10/04/urlretrieve-progres-indicator/
17     """
18     global start_time
19     if count == 0:
20         start_time = time.time()
21         return
22     duration = (time.time() - start_time) or 0.01
23     progress_size = int(count * block_size)
24     speed = int(progress_size / (1024 * duration))
25     percent = int(count * block_size * 100 / total_size)
26     sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
27                     (percent, progress_size / (1024 * 1024), speed, duration))
28     sys.stdout.flush()
29
30
31 def parse_readme_frontmatter(dirname):
32     readme_filename = os.path.join(dirname, 'readme.md')
33     with open(readme_filename) as f:
34         lines = [line.strip() for line in f.readlines()]
35     top = lines.index('---')
36     bottom = lines.index('---', top + 1)
37     frontmatter = yaml.load('\n'.join(lines[top + 1:bottom]))
38     assert all(key in frontmatter for key in required_keys)
39     return dirname, frontmatter
40
41
42 def valid_dirname(dirname):
43     try:
44         return parse_readme_frontmatter(dirname)
45     except Exception as e:
46         print('ERROR: {}'.format(e))
47         raise argparse.ArgumentTypeError(
48             'Must be valid Caffe model directory with a correct readme.md')
49
50
51 if __name__ == '__main__':
52     parser = argparse.ArgumentParser(
53         description='Download trained model binary.')
54     parser.add_argument('dirname', type=valid_dirname)
55     args = parser.parse_args()
56
57     # A tiny hack: the dirname validator also returns readme YAML frontmatter.
58     dirname = args.dirname[0]
59     frontmatter = args.dirname[1]
60     model_filename = os.path.join(dirname, frontmatter['caffemodel'])
61
62     # Closure-d function for checking SHA1.
63     def model_checks_out(filename=model_filename, sha1=frontmatter['sha1']):
64         with open(filename, 'rb') as f:
65             return hashlib.sha1(f.read()).hexdigest() == sha1
66
67     # Check if model exists.
68     if os.path.exists(model_filename) and model_checks_out():
69         print("Model already exists.")
70         sys.exit(0)
71
72     # Download and verify model.
73     urllib.request.urlretrieve(
74         frontmatter['caffemodel_url'], model_filename, reporthook)
75     if not model_checks_out():
76         print('ERROR: model did not download correctly! Run this again.')
77         sys.exit(1)