9 from six.moves import urllib
# Frontmatter fields that every model readme.md has to define.
required_keys = [
    'caffemodel',
    'caffemodel_url',
    'sha1',
]
def reporthook(count, block_size, total_size):
    """Print a one-line download progress indicator for ``urlretrieve``.

    From http://blog.moleculea.com/2012/10/04/urlretrieve-progres-indicator/

    Args:
        count: number of blocks transferred so far; 0 on the initial call
            made when the connection is established.
        block_size: size of each block in bytes.
        total_size: total download size in bytes; may be -1 (or 0) when the
            size is unknown, in which case no percentage is shown.
    """
    global start_time
    if count == 0:
        # First call: remember when the transfer started; nothing to show yet.
        start_time = time.time()
        return
    # Avoid dividing by zero if the clock has not advanced yet.
    duration = (time.time() - start_time) or 0.01
    progress_size = int(count * block_size)
    speed = int(progress_size / (1024 * duration))
    if total_size > 0:
        # The last block may be partial, so count * block_size can overshoot.
        percent = min(int(count * block_size * 100 / total_size), 100)
        sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
                         (percent, progress_size // (1024 * 1024), speed,
                          duration))
    else:
        sys.stdout.write("\r...%d MB, %d KB/s, %d seconds passed" %
                         (progress_size // (1024 * 1024), speed, duration))
    # The line ends in '\r', not '\n', so flush to show the update right away.
    sys.stdout.flush()
def parse_readme_frontmatter(dirname):
    """Parse the YAML frontmatter of ``<dirname>/readme.md``.

    The frontmatter is the text between the first two lines that consist
    solely of ``---`` (surrounding whitespace ignored).

    Returns:
        ``(dirname, frontmatter)``, where ``frontmatter`` is the parsed mapping.

    Raises:
        ValueError: if the readme has no ``---``-delimited frontmatter, the
            frontmatter is not a mapping, or a key in ``required_keys`` is
            missing.
        IOError/OSError: if the readme cannot be opened.
        yaml.YAMLError: if the frontmatter is not valid YAML.
    """
    readme_filename = os.path.join(dirname, 'readme.md')
    with open(readme_filename) as f:
        # Only trailing whitespace is dropped: YAML indentation is significant.
        lines = [line.rstrip() for line in f]
    markers = [i for i, line in enumerate(lines) if line.strip() == '---']
    if len(markers) < 2:
        raise ValueError(
            '{} has no "---"-delimited YAML frontmatter'.format(readme_filename))
    top, bottom = markers[0], markers[1]
    # safe_load: the readme may come from an untrusted source, and yaml.load
    # without an explicit Loader is an error in PyYAML >= 6.
    frontmatter = yaml.safe_load('\n'.join(lines[top + 1:bottom]))
    if not isinstance(frontmatter, dict):
        raise ValueError(
            '{} frontmatter is not a YAML mapping'.format(readme_filename))
    # Explicit check instead of assert, which is stripped under `python -O`.
    missing = [key for key in required_keys if key not in frontmatter]
    if missing:
        raise ValueError('{} frontmatter is missing required keys: {}'.format(
            readme_filename, ', '.join(missing)))
    return dirname, frontmatter
def valid_dirname(dirname):
    """argparse ``type=`` converter for a Caffe model directory.

    Returns the ``(dirname, frontmatter)`` pair from
    ``parse_readme_frontmatter``, so the caller also gets the parsed readme.

    Raises:
        argparse.ArgumentTypeError: if the directory's readme.md cannot be
            read or parsed, so that argparse reports a usage error.
    """
    try:
        return parse_readme_frontmatter(dirname)
    except Exception as e:
        # This is the CLI boundary: every parse failure is reported the same
        # way, after showing the underlying cause.
        print('ERROR: {}'.format(e))
        raise argparse.ArgumentTypeError(
            'Must be valid Caffe model directory with a correct readme.md')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Download trained model binary.')
    parser.add_argument('dirname', type=valid_dirname)
    args = parser.parse_args()

    # A tiny hack: the dirname validator also returns readme YAML frontmatter,
    # so args.dirname is a (dirname, frontmatter) pair.
    dirname, frontmatter = args.dirname
    model_filename = os.path.join(dirname, frontmatter['caffemodel'])

    # Default arguments bind the target file and expected digest at def time.
    def model_checks_out(filename=model_filename, sha1=frontmatter['sha1']):
        """Return True if the SHA1 hex digest of ``filename`` equals ``sha1``."""
        digest = hashlib.sha1()
        with open(filename, 'rb') as f:
            # Hash in 1 MiB chunks so large model files are not read whole.
            for chunk in iter(lambda: f.read(1 << 20), b''):
                digest.update(chunk)
        # hexdigest() is lowercase; the YAML value may be upper-case or not
        # a str at all, so normalise it before comparing.
        return digest.hexdigest() == str(sha1).strip().lower()

    # Check if model exists.
    if os.path.exists(model_filename) and model_checks_out():
        print("Model already exists.")
        raise SystemExit(0)

    # Download and verify model.
    urllib.request.urlretrieve(
        frontmatter['caffemodel_url'], model_filename, reporthook)
    # The progress indicator leaves the cursor mid-line; finish that line.
    print('')
    if not model_checks_out():
        print('ERROR: model did not download correctly! Run this again.')
        raise SystemExit(1)