Add phase support for draw net
authorCarl Doersch <cdoersch@cs.cmu.edu>
Sun, 3 Jul 2016 19:32:19 +0000 (12:32 -0700)
committerCarl Doersch <cdoersch@cs.cmu.edu>
Tue, 5 Jul 2016 23:06:58 +0000 (16:06 -0700)
python/caffe/draw.py
python/draw_net.py

index 61205ca..9eecf6d 100644 (file)
@@ -127,7 +127,7 @@ def choose_color_by_layertype(layertype):
     return color
 
 
-def get_pydot_graph(caffe_net, rankdir, label_edges=True):
+def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
     """Create a data structure which represents the `caffe_net`.
 
     Parameters
@@ -137,6 +137,9 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
         Direction of graph layout.
     label_edges : boolean, optional
         Label the edges (default is True).
+    phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
+        Include layers from this network phase.  If None, include all layers.
+        (the default is None)
 
     Returns
     -------
@@ -148,6 +151,19 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
     pydot_nodes = {}
     pydot_edges = []
     for layer in caffe_net.layer:
+        if phase is not None:
+          included = False
+          if len(layer.include) == 0:
+            included = True
+          if len(layer.include) > 0 and len(layer.exclude) > 0:
+            raise ValueError('layer ' + layer.name + ' has both include '
+                             'and exclude specified.')
+          for layer_phase in layer.include:
+            included = included or layer_phase.phase == phase
+          for layer_phase in layer.exclude:
+            included = included and not layer_phase.phase == phase
+          if not included:
+            continue
         node_label = get_layer_label(layer, rankdir)
         node_name = "%s_%s" % (layer.name, layer.type)
         if (len(layer.bottom) == 1 and len(layer.top) == 1 and
@@ -186,7 +202,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
     return pydot_graph
 
 
-def draw_net(caffe_net, rankdir, ext='png'):
+def draw_net(caffe_net, rankdir, ext='png', phase=None):
     """Draws a caffe net and returns the image string encoded using the given
     extension.
 
@@ -195,16 +211,19 @@ def draw_net(caffe_net, rankdir, ext='png'):
     caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
     ext : string, optional
         The image extension (the default is 'png').
+    phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
+        Include layers from this network phase.  If None, include all layers.
+        (the default is None)
 
     Returns
     -------
     string :
         Postscript representation of the graph.
     """
-    return get_pydot_graph(caffe_net, rankdir).create(format=ext)
+    return get_pydot_graph(caffe_net, rankdir, phase=phase).create(format=ext)
 
 
-def draw_net_to_file(caffe_net, filename, rankdir='LR'):
+def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None):
     """Draws a caffe net, and saves it to file using the format given as the
     file extension. Use '.raw' to output raw text that you can manually feed
     to graphviz to draw graphs.
@@ -216,7 +235,10 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR'):
         The path to a file where the networks visualization will be stored.
     rankdir : {'LR', 'TB', 'BT'}
         Direction of graph layout.
+    phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
+        Include layers from this network phase.  If None, include all layers.
+        (the default is None)
     """
     ext = filename[filename.rfind('.')+1:]
     with open(filename, 'wb') as fid:
-        fid.write(draw_net(caffe_net, rankdir, ext))
+        fid.write(draw_net(caffe_net, rankdir, ext, phase))
index ec76a74..dfe70d2 100755 (executable)
@@ -28,6 +28,11 @@ def parse_args():
                               'http://www.graphviz.org/doc/info/'
                               'attrs.html#k:rankdir'),
                         default='LR')
+    parser.add_argument('--phase',
+                        help=('Which network phase to draw: can be TRAIN, '
+                              'TEST, or ALL.  If ALL, then all layers are drawn '
+                              'regardless of phase.'),
+                        default="ALL")
 
     args = parser.parse_args()
     return args
@@ -38,7 +43,15 @@ def main():
     net = caffe_pb2.NetParameter()
     text_format.Merge(open(args.input_net_proto_file).read(), net)
     print('Drawing net to %s' % args.output_image_file)
-    caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir)
+    phase=None;
+    if args.phase == "TRAIN":
+        phase = caffe.TRAIN
+    elif args.phase == "TEST":
+        phase = caffe.TEST
+    elif args.phase != "ALL":
+        raise ValueError("Unknown phase: " + args.phase)
+    caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
+                                phase)
 
 
 if __name__ == '__main__':