return color
-def get_pydot_graph(caffe_net, rankdir, label_edges=True):
+def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
"""Create a data structure which represents the `caffe_net`.
Parameters
Direction of graph layout.
label_edges : boolean, optional
Label the edges (default is True).
+ phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
+ Include layers from this network phase. If None, include all layers.
+ (the default is None)
Returns
-------
pydot_nodes = {}
pydot_edges = []
for layer in caffe_net.layer:
+ if phase is not None:
+ included = False
+ if len(layer.include) == 0:
+ included = True
+ if len(layer.include) > 0 and len(layer.exclude) > 0:
+ raise ValueError('layer ' + layer.name + ' has both include '
+ 'and exclude specified.')
+ for layer_phase in layer.include:
+ included = included or layer_phase.phase == phase
+ for layer_phase in layer.exclude:
+ included = included and not layer_phase.phase == phase
+ if not included:
+ continue
node_label = get_layer_label(layer, rankdir)
node_name = "%s_%s" % (layer.name, layer.type)
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
return pydot_graph
-def draw_net(caffe_net, rankdir, ext='png'):
+def draw_net(caffe_net, rankdir, ext='png', phase=None):
"""Draws a caffe net and returns the image string encoded using the given
extension.
caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
ext : string, optional
The image extension (the default is 'png').
+ phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
+ Include layers from this network phase. If None, include all layers.
+ (the default is None)
Returns
-------
string :
Postscript representation of the graph.
"""
- return get_pydot_graph(caffe_net, rankdir).create(format=ext)
+ return get_pydot_graph(caffe_net, rankdir, phase=phase).create(format=ext)
-def draw_net_to_file(caffe_net, filename, rankdir='LR'):
+def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None):
"""Draws a caffe net, and saves it to file using the format given as the
file extension. Use '.raw' to output raw text that you can manually feed
to graphviz to draw graphs.
The path to a file where the networks visualization will be stored.
rankdir : {'LR', 'TB', 'BT'}
Direction of graph layout.
+ phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
+ Include layers from this network phase. If None, include all layers.
+ (the default is None)
"""
ext = filename[filename.rfind('.')+1:]
with open(filename, 'wb') as fid:
- fid.write(draw_net(caffe_net, rankdir, ext))
+ fid.write(draw_net(caffe_net, rankdir, ext, phase))
'http://www.graphviz.org/doc/info/'
'attrs.html#k:rankdir'),
default='LR')
+ parser.add_argument('--phase',
+ help=('Which network phase to draw: can be TRAIN, '
+ 'TEST, or ALL. If ALL, then all layers are drawn '
+ 'regardless of phase.'),
+ default="ALL")
args = parser.parse_args()
return args
net = caffe_pb2.NetParameter()
text_format.Merge(open(args.input_net_proto_file).read(), net)
print('Drawing net to %s' % args.output_image_file)
- caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir)
+ phase=None;
+ if args.phase == "TRAIN":
+ phase = caffe.TRAIN
+ elif args.phase == "TEST":
+ phase = caffe.TEST
+ elif args.phase != "ALL":
+ raise ValueError("Unknown phase: " + args.phase)
+ caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
+ phase)
if __name__ == '__main__':