Update the drawnet.py to reflect the recent revised net definition.
authorZhiHeng NIU <niuzhiheng@gmail.com>
Fri, 25 Apr 2014 09:43:56 +0000 (17:43 +0800)
committerZhiHeng NIU <niuzhiheng@gmail.com>
Fri, 25 Apr 2014 09:43:56 +0000 (17:43 +0800)
python/caffe/drawnet.py

index 8ff0d83..de5a876 100644 (file)
@@ -15,14 +15,21 @@ NEURON_LAYER_STYLE = {'shape': 'record', 'fillcolor': '#90EE90',
          'style': 'filled'}
 BLOB_STYLE = {'shape': 'octagon', 'fillcolor': '#F0E68C',
         'style': 'filled'}
+def get_enum_name_by_value():
+  desc = caffe_pb2.LayerParameter.LayerType.DESCRIPTOR
+  d = {}
+  for k,v in desc.values_by_name.items():
+    d[v.number] = k
+  return d
 
 def get_pydot_graph(caffe_net):
-  pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph')
+  pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph', rankdir="BT")
   pydot_nodes = {}
   pydot_edges = []
+  d = get_enum_name_by_value()
   for layer in caffe_net.layers:
-    name = layer.layer.name
-    layertype = layer.layer.type
+    name = layer.name
+    layertype = d[layer.type]
     if (len(layer.bottom) == 1 and len(layer.top) == 1 and
         layer.bottom[0] == layer.top[0]):
       # We have an in-place neuron layer.
@@ -63,7 +70,7 @@ def draw_net_to_file(caffe_net, filename):
   to graphviz to draw graphs.
   """
   ext = filename[filename.rfind('.')+1:]
-  with open(filename, 'w') as fid:
+  with open(filename, 'wb') as fid:
     fid.write(draw_net(caffe_net, ext))
 
 if __name__ == '__main__':