Expose LayerFactory::LayerTypeList in pycaffe
authorLuke Yeager <lyeager@nvidia.com>
Fri, 14 Aug 2015 23:53:39 +0000 (16:53 -0700)
committerLuke Yeager <lyeager@nvidia.com>
Thu, 20 Aug 2015 17:06:22 +0000 (10:06 -0700)
Useful for validating NetParameters without crashing on SIGABRT

include/caffe/layer_factory.hpp
python/caffe/__init__.py
python/caffe/_caffe.cpp
python/caffe/test/test_layer_type_list.py [new file with mode: 0644]

index 32e849d..2c2fde4 100644 (file)
@@ -41,6 +41,7 @@
 
 #include <map>
 #include <string>
+#include <vector>
 
 #include "caffe/common.hpp"
 #include "caffe/proto/caffe.pb.h"
@@ -77,26 +78,36 @@ class LayerRegistry {
     const string& type = param.type();
     CreatorRegistry& registry = Registry();
     CHECK_EQ(registry.count(type), 1) << "Unknown layer type: " << type
-        << " (known types: " << LayerTypeList() << ")";
+        << " (known types: " << LayerTypeListString() << ")";
     return registry[type](param);
   }
 
+  static vector<string> LayerTypeList() {
+    CreatorRegistry& registry = Registry();
+    vector<string> layer_types;
+    for (typename CreatorRegistry::iterator iter = registry.begin();
+         iter != registry.end(); ++iter) {
+      layer_types.push_back(iter->first);
+    }
+    return layer_types;
+  }
+
  private:
   // Layer registry should never be instantiated - everything is done with its
   // static variables.
   LayerRegistry() {}
 
-  static string LayerTypeList() {
-    CreatorRegistry& registry = Registry();
-    string layer_types;
-    for (typename CreatorRegistry::iterator iter = registry.begin();
-         iter != registry.end(); ++iter) {
-      if (iter != registry.begin()) {
-        layer_types += ", ";
+  static string LayerTypeListString() {
+    vector<string> layer_types = LayerTypeList();
+    string layer_types_str;
+    for (vector<string>::iterator iter = layer_types.begin();
+         iter != layer_types.end(); ++iter) {
+      if (iter != layer_types.begin()) {
+        layer_types_str += ", ";
       }
-      layer_types += iter->first;
+      layer_types_str += *iter;
     }
-    return layer_types;
+    return layer_types_str;
   }
 };
 
index 1b2da51..6cc44e7 100644 (file)
@@ -1,5 +1,5 @@
 from .pycaffe import Net, SGDSolver
-from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver
+from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list
 from .proto.caffe_pb2 import TRAIN, TEST
 from .classifier import Classifier
 from .detector import Detector
index bb5130f..f9b2dba 100644 (file)
@@ -200,6 +200,8 @@ BOOST_PYTHON_MODULE(_caffe) {
   bp::def("set_mode_gpu", &set_mode_gpu);
   bp::def("set_device", &Caffe::SetDevice);
 
+  bp::def("layer_type_list", &LayerRegistry<Dtype>::LayerTypeList);
+
   bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >("Net",
     bp::no_init)
     .def("__init__", bp::make_constructor(&Net_Init))
diff --git a/python/caffe/test/test_layer_type_list.py b/python/caffe/test/test_layer_type_list.py
new file mode 100644 (file)
index 0000000..7edc80d
--- /dev/null
@@ -0,0 +1,10 @@
+import unittest
+
+import caffe
+
+class TestLayerTypeList(unittest.TestCase):
+
+    def test_standard_types(self):
+        for type_name in ['Data', 'Convolution', 'InnerProduct']:
+            self.assertIn(type_name, caffe.layer_type_list(),
+                    '%s not in layer_type_list()' % type_name)