split drawnet into module code and script
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Mon, 19 May 2014 00:14:53 +0000 (17:14 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 20 May 2014 06:55:22 +0000 (23:55 -0700)
Don't run scripts in the module dir to avoid import collisions between
io and caffe.io.

python/caffe/draw.py [moved from python/caffe/drawnet.py with 87% similarity]
python/draw_net.py [new file with mode: 0755]

similarity index 87%
rename from python/caffe/drawnet.py
rename to python/caffe/draw.py
index ff18ecf..f8631cf 100644 (file)
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 """
 Caffe network visualization: draw the NetParameter protobuffer.
 
@@ -10,8 +9,6 @@ Caffe.
 from caffe.proto import caffe_pb2
 from google.protobuf import text_format
 import pydot
-import os
-import sys
 
 # Internal layer and blob styles.
 LAYER_STYLE = {'shape': 'record', 'fillcolor': '#6495ED',
@@ -77,15 +74,3 @@ def draw_net_to_file(caffe_net, filename):
   ext = filename[filename.rfind('.')+1:]
   with open(filename, 'wb') as fid:
     fid.write(draw_net(caffe_net, ext))
-
-if __name__ == '__main__':
-  if len(sys.argv) != 3:
-    print 'Usage: %s input_net_proto_file output_image_file' % \
-        os.path.basename(sys.argv[0])
-  else:
-    net = caffe_pb2.NetParameter()
-    text_format.Merge(open(sys.argv[1]).read(), net)
-    print 'Drawing net to %s' % sys.argv[2]
-    draw_net_to_file(net, sys.argv[2])
-
-
diff --git a/python/draw_net.py b/python/draw_net.py
new file mode 100755 (executable)
index 0000000..cbea5d9
--- /dev/null
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+"""
+Draw a graph of the net architecture.
+"""
+import os
+from google.protobuf import text_format
+
+import caffe
+from caffe.proto import caffe_pb2
+
+
+def main(argv):
+    if len(argv) != 3:
+        print 'Usage: %s input_net_proto_file output_image_file' % \
+                os.path.basename(sys.argv[0])
+    else:
+        net = caffe_pb2.NetParameter()
+        text_format.Merge(open(sys.argv[1]).read(), net)
+        print 'Drawing net to %s' % sys.argv[2]
+        draw_net_to_file(net, sys.argv[2])
+
+
+if __name__ == '__main__':
+    import sys
+    main(sys.argv)