Update release_notes.md
[platform/upstream/caffeonacl.git] / python / draw_net.py
1 #!/usr/bin/env python
2 """
3 Draw a graph of the net architecture.
4 """
5 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
6 from google.protobuf import text_format
7
8 import caffe
9 import caffe.draw
10 from caffe.proto import caffe_pb2
11
12
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Two positional arguments are required (the input network prototxt file
    and the output image file).  Two options are accepted: ``--rankdir``
    (default ``LR``) and ``--phase`` (default ``ALL``).

    Returns:
        The argparse namespace with attributes ``input_net_proto_file``,
        ``output_image_file``, ``rankdir`` and ``phase``.
    """
    # Help texts are assembled up front to keep the add_argument calls short.
    rankdir_help = ('One of TB (top-bottom, i.e., vertical), '
                    'RL (right-left, i.e., horizontal), or another '
                    'valid dot option; see '
                    'http://www.graphviz.org/doc/info/'
                    'attrs.html#k:rankdir')
    phase_help = ('Which network phase to draw: can be TRAIN, '
                  'TEST, or ALL.  If ALL, then all layers are drawn '
                  'regardless of phase.')

    arg_parser = ArgumentParser(
        description=__doc__,
        formatter_class=ArgumentDefaultsHelpFormatter,
    )

    # Required positionals.
    arg_parser.add_argument('input_net_proto_file',
                            help='Input network prototxt file')
    arg_parser.add_argument('output_image_file', help='Output image file')

    # Optional drawing controls.
    arg_parser.add_argument('--rankdir', help=rankdir_help, default='LR')
    arg_parser.add_argument('--phase', help=phase_help, default="ALL")

    return arg_parser.parse_args()
39
40
def main():
    """Draw the network described by a prototxt file to an image file.

    Reads the command-line arguments via ``parse_args()``, parses the input
    prototxt text into a ``NetParameter`` message, translates the ``--phase``
    option into the matching caffe phase constant and passes everything to
    ``caffe.draw.draw_net_to_file``.

    Raises:
        ValueError: if ``--phase`` is not one of TRAIN, TEST or ALL.
    """
    args = parse_args()
    net = caffe_pb2.NetParameter()
    # Read through a context manager so the prototxt file handle is closed
    # deterministically rather than left for the garbage collector.
    with open(args.input_net_proto_file) as proto_file:
        text_format.Merge(proto_file.read(), net)
    print('Drawing net to %s' % args.output_image_file)

    # phase stays None for "ALL", i.e. no phase is selected for drawing.
    phase = None
    if args.phase == "TRAIN":
        phase = caffe.TRAIN
    elif args.phase == "TEST":
        phase = caffe.TEST
    elif args.phase != "ALL":
        raise ValueError("Unknown phase: " + args.phase)
    caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
                                phase)
55
56
57 if __name__ == '__main__':
58     main()