Fix Python net drawing script
authorNico Galoppo <nico.galoppo@intel.com>
Mon, 21 Nov 2016 19:03:52 +0000 (11:03 -0800)
committerNico Galoppo <nico.galoppo@intel.com>
Mon, 21 Nov 2016 19:03:52 +0000 (11:03 -0800)
python/caffe/draw.py

index 9eecf6d..e4fd7aa 100644 (file)
@@ -104,11 +104,11 @@ def get_layer_label(layer, rankdir):
                       pooling_types_dict[layer.pooling_param.pool],
                       layer.type,
                       separator,
-                      layer.pooling_param.kernel_size,
+                      layer.pooling_param.kernel_size[0] if len(layer.pooling_param.kernel_size._values) else 1,
                       separator,
-                      layer.pooling_param.stride,
+                      layer.pooling_param.stride[0] if len(layer.pooling_param.stride._values) else 1,
                       separator,
-                      layer.pooling_param.pad)
+                      layer.pooling_param.pad[0] if len(layer.pooling_param.pad._values) else 0)
     else:
         node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
     return node_label